From 0a94bb432ed75cc2d950d81b2921363218a7e459 Mon Sep 17 00:00:00 2001
From: Xiaodong Wang
Date: Fri, 3 Jan 2025 22:01:36 +0000
Subject: [PATCH] [ROCm] CK Flash Attention Backend (#143695)

Replaces https://github.com/pytorch/pytorch/pull/138947 for re-import.

Replaces https://github.com/ROCm/pytorch/pull/1592

This PR contains the initial implementation of SDPA with the composable_kernel backend. The CK path can be forced by simply calling torch.backends.cuda.preferred_rocm_fa_library("ck"). Similarly, you can force the incumbent aotriton implementation by passing in "aotriton" or "default". As you'd expect, not setting this option will result in aotriton being used as the backend. In the case of CK, if PyTorch deems flash attention usable, it will use the CK path in all the same places aotriton would have been used.

This PR makes no changes to the heuristics that select which attention scheme to use (i.e. flash attention vs. memory-efficient attention vs. math, etc.). The CK path only gets called when flash attention is both enabled (via USE_FLASH_ATTENTION) and selected at runtime by the existing heuristics.

Files located in pytorch/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha* have been pulled from https://github.com/Dao-AILab/flash-attention, courtesy of @tridao's hard work; he is credited as a co-author.

NOTE: In order to use this backend, the user MUST set USE_CK_FLASH_ATTENTION=1 in their environment when they build PyTorch.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143695
Approved by: https://github.com/malfet

Co-authored-by: Andy Lugo
Co-authored-by: Jithun Nair
---
 LICENSE | 4 + aten/src/ATen/CMakeLists.txt | 27 +- aten/src/ATen/Context.cpp | 34 + aten/src/ATen/Context.h | 8 + aten/src/ATen/ROCmFABackend.h | 31 + .../native/transformers/cuda/sdp_utils.cpp | 28 +- .../transformers/hip/aotriton_adapter.h | 2 +- .../{flash_api.hip => aot/mha_all_aot.hip} | 39 +- .../transformers/hip/flash_attn/ck/bias.hpp | 100 + .../hip/flash_attn/ck/fmha_bwd.hpp | 447 + ...042c36bc588e60a7c8a9ba297a8a25d8ac0660.hip | 144 + ...29076f83a3dc695a167beda6fe19230a2b114b.hip | 144 + ...6c417a52a1bd7c55e45d111483d26f4480caeb.hip | 144 + ...8f2429c678d13386a06e8d8b15c4b480940ff3.hip | 79 + ...a2adbe938d458d51ca5fc4020667a215b672a4.hip | 144 + ...2c0f480917c329f4c3c6c666cf32af2d82b294.hip | 84 + ...4c209d5cfc6b965bfd78c64bf132c0154e32be.hip | 144 + ...53ec18d3ded0f8bdc6459ea5757ebd94d9faf2.hip | 79 + ...ac1a2ecf9a487809e46faa92e267df2d47de91.hip | 79 + ...ca79005067e20e4eed5a72ff9187cde702cd1c.hip | 144 + ...cb354dddef6e99e4ac843f2adafcddfc58d520.hip | 79 + ...d12033d59ce2799a2a024e5d9232325ccf1320.hip | 144 + ...d3b034a2d8d0b83c0aefa4faac6c3f28ce737f.hip | 144 + ...e2428c5447aa9a78f79f73f31cf685c586872d.hip | 144 + ...e8aedb7b7d77f44a46b2e9b7a826f245aaf4a7.hip | 84 + ...e8f0df0c54ce619e5b66441b3c96a5e18b05d6.hip | 144 + ...ee0083f6df962c4a754cd3295b1a436c590a0e.hip | 144 + ...f74764c3c3284fdd1b67d0ea781c2261ed0de6.hip | 144 + ...25857454eaab2eb664aef7a0849ce12c32fdf9.hip | 144 + ...37c76137df14fb808ade8bd6837045f2aaa5c9.hip | 144 + ...71bd8b7c270e1593871b638288a4923342c446.hip | 144 + ...d88a03cd3966dd0cff550065f58c3ffecfff6c.hip | 144 + ...ff94e3c787a7b06ffc90c25777fa74f225e32c.hip | 144 + ...0a759dcc92028b4c6f317fc230b98cb929e806.hip | 144 + ...1b12f9fd94e01aaff2c0da4f35f346822087e4.hip | 144 + ...6887daf6cc092e7422a17882488e59cecfb643.hip | 144 + ...7c6c80fcec3eb8b0bef50ad6af6d27bf5447f5.hip | 144 + ...92491c5a6dfc742c2be483419a40f6a7a7ea56.hip | 144 +
...a71615a088e972c998f9c7cb44566c268c5124.hip | 144 + ...ff035717140f7385282419598cb4fb2881ce8e.hip | 144 + ...1a0718891596ddac1fb0088637029233ccbe60.hip | 144 + ...2a156e9eb935555ab14a84461959b466c2fb5b.hip | 84 + ...641230fe9a50a221047f7a1df8a370f72805b9.hip | 144 + ...c363e11d202c6d2f4bb753661c5a2043edc0ad.hip | 144 + ...caeecbc01667ec6f5599358a0a20423aa9a00b.hip | 144 + ...f39b453505f68a5091f68b1c3de48369d1e7ea.hip | 144 + ...ffca078cfab8bc6c4ccd1cc8994a1bb4a88ea7.hip | 144 + ...02e718337eab7d47aa65cea7d3c5f641484520.hip | 144 + ...13b2f3bd8ad51315aadb7f63737201898adca8.hip | 144 + ...3981d9e7af2ebc0f91e61ac5e25cbe68c95bd8.hip | 84 + ...4fda16133a0d25077967b05425f9128e1fe1a5.hip | 144 + ...538339c21c92c53d237865d72debaaf2ee5075.hip | 144 + ...95316f0dfffda03e5296b959a49ec3f3c48d67.hip | 144 + ...dfe927fd64a564c5fad537fb7c41ee9c94c2c0.hip | 144 + ...e60b3ab7477f9edc8576a8bf43e3a62b8d5ef8.hip | 144 + ...f794c7023cbb7e35f1fd1ae45bd2377bfbc520.hip | 144 + ...28931bf5cc1daa6e106cf60bb21fa1aac6b1df.hip | 144 + ...2c8c3c1cf6c33af4574099e9b6ac54a55ad776.hip | 144 + ...82150e93f547e00f13cd8984779bf49b91e50c.hip | 84 + ...9c663be0267c009be4814e9e4e7c13ec999411.hip | 144 + ...ae52ef937cc27c544e32025ea0dadb7fad982d.hip | 144 + ...b74acd9abfbd1c4ec2f4c718eeb92a0bca7bab.hip | 144 + ...ba94794a14f0f0022af6f5f3c16e1e16959d4c.hip | 144 + ...1751b1012b90f7b57f8591cd06ae1fd27d9cd3.hip | 84 + ...66e7aa4b263a811408b285213e47176ee2bdaf.hip | 84 + ...6b3beb57b30afb30636f948e3989b346b38d20.hip | 144 + ...89852b0cd3cc030c78b28f2fd5b6b0546382a4.hip | 84 + ...8b96ad691a85eebd18586db0b62b8911016d9c.hip | 144 + ...c3fc96d2bebe546dce6ebf46e5c7a519959599.hip | 84 + ...ff04fcc273e469737512893ea3fb5876ac131d.hip | 144 + ...01c56831b4c6428200db6318638a2129bb197a.hip | 79 + ...36d5dfc0f939ab9a4064b403339373caf35b56.hip | 144 + ...42c4e3aabdf55405b3ce09ce1899245ddf11ad.hip | 144 + ...5722b43cde5f37242edb071f639da7c4a0bd48.hip | 144 + ...78b9aa31429d23a93cd953cc6a2fc5f43d0d3a.hip | 144 + ...9a347aef8a920e3b59d5ffe71fc5bfe002609c.hip | 144 + ...9de13222caec1483207d4a54249f8da4f9c151.hip | 144 + ...1cb49c1958fb4342d79f367ea93cf2b472f785.hip | 144 + ...3834d4d3fe76e1745e4482c6b51b550c6f3dfc.hip | 144 + ...513bff5c1da6aadf11d2e8272a422eabff21bc.hip | 144 + ...6863cd93d1b105a617d0daa1d4f37d7fb6b893.hip | 144 + ...68cebd81ade762c2f92fffc0153fa7a2b91eb5.hip | 144 + ...6e888c52d0f4a5847d7515fcc66208b1ff40d3.hip | 144 + ...7b3e1dae9bfb2e89398706508f8e01966fd4ea.hip | 144 + ...d76cca48b71dbcc9bd96734787209fee4c9a74.hip | 144 + ...e50367b62bb09071e28b44235a7c112645a706.hip | 144 + ...ecb6347009f6a5d5530a6acf90f9f40288cbcf.hip | 84 + ...2b116fd5065109aae46ee547e4f49ad0e9d6e1.hip | 84 + ...4e76d89b175e1d9fd2e9fb908d5fce1ebb945d.hip | 84 + ...55ed15ef58c941e06dda890aeb530e28eb7bba.hip | 144 + ...672fca51de618e3441cf8764e8e83eb782f2c7.hip | 144 + ...68c2f9a3acdd787b81be455cbc7836c8bfd90c.hip | 84 + ...89417a043556970f72eebd48b4f3e7ac15377a.hip | 144 + ...92671b6ea99891c0d69b1c793f4d131b9a82ed.hip | 144 + ...afb881e34a3794970a1282af740b3f19c138b1.hip | 84 + ...ce6e29e1d3060c3086c08fe27b471e375f9c75.hip | 84 + ...d9d68fcee021437e13ffdf94d78252205f5a31.hip | 84 + ...2647b5982405a48e8c8888552a4b89386ccdd9.hip | 84 + ...2efefea81036641561bed80c75d77651176f74.hip | 144 + ...3153af7bcdba33115a0d31f121fd76be2ffbcc.hip | 144 + ...532fcf26f90c82a792cde7943634f667c1d033.hip | 144 + ...90a0186d8b8004e3f19886c7992c8e04d0e066.hip | 84 + ...9585ba1c10acf67115c5899b3546608541820d.hip | 144 + ...b81407c8a2b3cdc5fecf655b3ad64d5d729cc9.hip | 84 + 
...c7910aac798f0555e9e505ad7f177c9fbbd92c.hip | 71 + ...e8cf70c6be969ecfca675782c860b5b75ac089.hip | 144 + ...efed50a89d80c22b2c8c3d5ba67d73c3d0190e.hip | 84 + ...32a2d9701e23dd930119c4ee8089042b5b0ac5.hip | 144 + ...3b2ec99fa7b09c7f78dcc3142a661d686044ac.hip | 144 + ...8a0bb89a6f05289c0405df5126fa0cc16252e7.hip | 144 + ...93c65e5942a2f43f2e491547add02777dd2eee.hip | 79 + ...9bd38b8f9009d932ec49204fdea39a52885246.hip | 144 + ...aeedaa7d50f1741d618fb6c573529eebb075b1.hip | 144 + ...def49859c80c6b3ba18eb2fb4c35c72abc1cf2.hip | 144 + ...ee6b9427c164d78994150305a47f73954a67c0.hip | 144 + ...0e0147a92061d32608a34e7b47bd534eb787fa.hip | 144 + ...13a4c8d169877da6408584dc1f20a6f7c5e3aa.hip | 144 + ...de401aa76cb5425563cbbdb0362748148da3ca.hip | 84 + ...007c36231ccdae12f102eacca1f74b0711b9c6.hip | 144 + ...0a2370f2a320484d8f9f21e3197425c2dbe9ad.hip | 71 + ...1dbc9c433ce8ec33ace9e62550261d613db582.hip | 71 + ...3f4cd28a4c06cc109f6a0798a77844bcc750b7.hip | 144 + ...661b5f30566d1f159f060c264849c7ae4772f1.hip | 84 + ...bacd06455ab20eba78b389462946716b5819f6.hip | 84 + ...f309b923172f4c0fb38d9b9f5325b33b4877c2.hip | 144 + ...f9b9413697d6f4573c6605bff6f58d027c5016.hip | 144 + ...fdaa9266a5a464009297dc59db92504f8bf1a3.hip | 144 + ...0c699d9c3b0ed62097e38ba05e40e815cf474e.hip | 144 + ...588dcb2ef86677ebf84e406eb802e9921d1f1e.hip | 144 + ...bb0bef3b388867e75d7a8a187b8b4b650a42ae.hip | 144 + ...bddf533661642d84bf5a16149692d5a892182a.hip | 144 + ...cb7492feb79e27e0bda73e57ef7dab410e2bb6.hip | 144 + ...d4068ea93fcf4df463e3bf3a6898d23b65da7f.hip | 144 + ...3186dbad604763008e0204a1ea90baecef8877.hip | 84 + ...37f1bc50c4a65dac09ba56b701256b701c4322.hip | 144 + ...a055e5c3d6a953d470db5dc21449766248058a.hip | 79 + ...c24f1f9009e46afa3a59193784cc2575f79056.hip | 144 + ...ceed95b0a0a01f844678717c88e0426fb503fd.hip | 144 + ...32b11429034d96d82c82dbfdb69e460ad8a564.hip | 144 + ...e7df31541c3aa919e9825ad7dc4432f9a03c0c.hip | 144 + ...ff174ff2175e9ec22ac3a0fa59dd7713b79643.hip | 144 + ...11733062ed30b876f1d63bffa642d77e258dd6.hip | 84 + ...207f4b6e7fac27d6c16493a5373f448a2aaae8.hip | 84 + ...41814f76107d74ed069ecec99a248676487eee.hip | 144 + ...d5c8a4988efe60ef7943ecd73e18a28a736583.hip | 84 + ...d60c8abecb3bc9b84b0ea7851628ab17d8b0b3.hip | 144 + ...1691f01cc7f29affb88152dd48c7a484315dcd.hip | 144 + ...1c1fdc4206bb952b2fea675f24e3b09f605eef.hip | 79 + ...3c51948cf8584900807998da14d788039f53b9.hip | 144 + ...5ea67de101135ed5fe04f5cab1ec1d7b3714bb.hip | 84 + ...7fa6780d9e6bde10aec10a875c039fdbbc652e.hip | 144 + ...86cd75411e61a8dbbaf2b916e62f4f5f99104f.hip | 144 + ...d5f2ec83b3331654e37ea0b44d88cd98abaa37.hip | 144 + ...f747525ad31e76c88774fb2208e470da9c2310.hip | 144 + ...221590b90c48d3cf259fb4e834ccfaf7f3209b.hip | 84 + ...4f19363ef26efd36f0436cfa9f84f181a8824c.hip | 144 + ...6eb8c40e3146e06936f3141b2c4d92a578ddec.hip | 84 + ...baaaf1e90a075ab802c6e7d97c4b1605c8bd72.hip | 144 + ...c4ebd1792c781d219bd21b691b575f64635730.hip | 144 + ...d11aad7b666f500f68b264a2fcca6dfc5f1a05.hip | 144 + ...d4630876785655bd4950566e81ae0b645c0d3c.hip | 144 + ...f77aeeafe4b28f314fde5ebccfd2a554872781.hip | 144 + ...fea611f3c253aebf726af3e5fdb7e63e18e13a.hip | 144 + ...1a4425b411596c46c7032f6b83d3152a0e0cd4.hip | 144 + ...3e897098539c3466da9d7a37234daf16476277.hip | 144 + ...52dc38d26f6badb7a9bcb5ce9124d54cc45ed3.hip | 144 + ...5bafb551768855c8c01faa63e44764ebe6c110.hip | 84 + ...5c3549d067464d186a99b8205317cc000d4898.hip | 144 + ...73e3d855d28c54af612ab950b081302891d56d.hip | 84 + ...7768cd725813f8111d265cfdfea7f42034e5e9.hip | 84 + 
...7b89d8d625b8244b5cceaa4d3e5fc5a09c8989.hip | 144 + ...8d5ce564c3ae1eefb54e3d41dde2604560ef4a.hip | 144 + ...9ee1f1b44d1a8fbaead65d8449413bb616d15e.hip | 84 + ...b255dde1a9d915e582ee2a83de7d83190c6a24.hip | 84 + ...cf7068183421b141ed5d6e7fe902d06b6492a1.hip | 144 + ...dc02ea7e0908cf0bd48034f5a49debfaa36219.hip | 144 + ...e8e1ab8c63db96843054bb7a98d708ae6a9c44.hip | 144 + ...fe3e8f4add16a088fe44458353fa7c0c4f9658.hip | 144 + ...047b5544acef40e39932672cac6f562e200948.hip | 144 + ...21507cf219fe608715d4e5bb6e5764022e2d61.hip | 144 + ...2b0dfbe3f615b1d164290799b2457437a0044b.hip | 144 + ...4a947a6c2ba83a5b1cb7074aee0bdac6c9c64e.hip | 144 + ...5dfb45658df8f1ae8dc0738ac9614740f2576c.hip | 144 + ...7f5328b035ed59a6f05dfee31edd704c4b07ee.hip | 84 + ...87ddf65ce4ed2997583e20fee9f201e86633b3.hip | 144 + ...f94f5c65c37624f5458c165daf83517d9e3c81.hip | 144 + ...3c44dd85077e6b12dd06fdcf6b11ba349e1866.hip | 144 + ...b9b96edda151072215502cc2b606bf1f6f0b03.hip | 144 + ...47fef2c06ea581b0ab31af1cb0556c572696ad.hip | 144 + ...7963e1969301abfa61d06afc97faea2bb4efb1.hip | 71 + ...86d4bf54b3a4a9e093360998b2059b3c03d072.hip | 144 + ...8a70d526394e254274df95de0727850820326c.hip | 84 + ...99e28aff2fb168cdc3af7132dd7fd09c2e1ced.hip | 144 + ...a4d71b31c451a50df7996e3db864bc3c3882ed.hip | 144 + ...b92b4e249195ac3e0c74d246585a4c9e0992fd.hip | 71 + ...ed7195a9443c84956c3f32839cb3ab9056bdfc.hip | 144 + ...14250fce818584291c69a5f058a58cfbd83df9.hip | 144 + ...3699a5daa14ca2def07489e0b563149bc403f8.hip | 144 + ...af6a7f9e5020e8d0f0ca0f6258001f6ce592c1.hip | 144 + ...cd9f7b08cec83736605af63d9fcaf463a1aea4.hip | 71 + ...df4e13108e043361e9528b71df56f04f696a0c.hip | 144 + ...11dd5ebb989503a1c182684e7f247e2f8cd9c2.hip | 71 + ...236be9da05a07d11cd28034d90cdf89941a172.hip | 144 + ...5e18f6333ed2cce509f07cb8bd5868951d66a0.hip | 144 + ...6785392af35e27d6697b584cb6f17a766d3fee.hip | 144 + ...6bc2762b95d550485aa720edaf71138d94cd07.hip | 144 + ...8da3e6ab050262b659c801ccf9a14787d7f176.hip | 144 + ...96f0ac76f117e66eba97cb990c2350561ec2ab.hip | 144 + ...98bcbe900f8c141136d18c114b02fffbe8bca1.hip | 144 + ...99b2625adffa8215276bb88fc65bae944b846b.hip | 144 + ...cf2f892742b1d236d2b31a8185c6869126adad.hip | 84 + ...3e7c8969027d3316875f33dc50fe022e05ce37.hip | 79 + ...e43f8b629e7039f57b95866d5777273377470d.hip | 144 + ...e746990a2032f0363ad9f9112cc994983f4706.hip | 144 + ...f767e7104cfc8322f26df35907fbf04b8948f3.hip | 144 + ...1b0f85e085dd0769c566fb16aafe5ab5952714.hip | 144 + ...2a2d78176e3f0a78e3ad78217e75a4430c0de5.hip | 144 + ...65ba6dba01da9caa84ba89453b61d81376763f.hip | 84 + ...a3f45d0be2d1119cccd0af042a3e8adeda2ed7.hip | 1965 +++ ...bf88db44aa5f884438288a325270d29c7a04b6.hip | 84 + ...c459e57bfed5ec7f40ea4a4dd9f72f3ad7a709.hip | 144 + ...02609fb803ea2697e2c2cef35e6f923d2578cf.hip | 144 + ...0b822743e0205f60521d38d7c64f589fdf0f58.hip | 144 + ...21263e16dafe79b9fe2f998847296e575c14e7.hip | 84 + ...3ef3d5ded0dfe2a0bafb52ea8f841658db35fd.hip | 84 + ...498e418ebbf33bed58b4074d1edf3d9bdd07c5.hip | 144 + ...a23de9604b5d98fe02529075bad995954c12ca.hip | 144 + ...b03461737f1e359f389a8d297476f9b60faabd.hip | 144 + ...c6e599144a093203fd7f92ac6d3c2cd7180d49.hip | 144 + ...e2f97d49f015b9af0b186801e939c6f357a0c4.hip | 144 + ...f893ee660d37fba7eaca452ae65b3e45a73087.hip | 84 + ...22f2d99804198c61251b4629a3f18ed3dcd42e.hip | 144 + ...33ce1fa113b221e5303b4093c2c4e748ce8298.hip | 144 + ...42736d4f677a59a172bd6f162616a437696351.hip | 144 + ...7d7888480b83c78833214b32e10f37a6e20301.hip | 79 + ...9130607a2d24cb0662a47e9cf12c6602143838.hip | 144 + 
...943fcc2e64c618fc1415b3f1a0db4d70aa8494.hip | 144 + ...daf9d4270d2ac61c299320e06ba73f44730364.hip | 84 + ...0cad6ad5b172e51c569e84cd54a19b4eb0ed05.hip | 144 + ...13a6d0f8c798c0c4ba4ad202d081899fe081ab.hip | 144 + ...6bc5faf18be193212217788d476ce6fd384bfb.hip | 144 + ...7faa0b33a9aada86f032174afd40d18efa7715.hip | 144 + ...81f8cce0d77dec9f977b9eeb0778b70a13fa75.hip | 144 + ...cdcb750f382fc7828a9886585f50efbe5be735.hip | 84 + ...d9fa7c2e13d0bad5fddb2b5a316bbc09d397ea.hip | 84 + ...da1c96568eab89a8f6498f8bb23c1223cdc7b0.hip | 84 + ...05aca3520b171bb82d10ad70fef44f28c19776.hip | 84 + ...4a573ce6b7d2f90aede543939315561cc43177.hip | 84 + ...588bcac681a5d69f252d7523a3681a0c6b6181.hip | 144 + ...81430c92864c29bb9f409e7c27caee1de00749.hip | 144 + ...d5c3c86398f6ce55abc90db3e362dbf9f457f2.hip | 84 + ...f7ea0aabd069362ba4bbd66623cea5b6e1a6bd.hip | 144 + ...0ef512b7862837f54acbc3b21e135a192647a3.hip | 84 + ...22c973581930ab7a4ebc90b3bf1cdaa229a87f.hip | 144 + ...411df58165946bf02942b597d94de7dd856987.hip | 144 + ...6806a4598c885e517e664fc8280c59ec3cbf11.hip | 144 + ...73b7c710d418f44dc2b41bec5905024334eae5.hip | 84 + ...77d95cdf45f6fec95d1812f2ef183a75259e38.hip | 144 + ...828c7d3f5574690f12f841c27f025206e6165b.hip | 144 + ...84fba2eec5899bb40d49d4508196e6be1ec1b1.hip | 144 + ...e235e31d6955393ac8e825bd69ead70687b7c8.hip | 144 + ...f860d42fdc2cc6bd743d53ba546e332c22fedf.hip | 144 + ...105635385fbfb5d2f330df83ba6747bcb27f6d.hip | 79 + ...4f9af5e5ca519b21b71a54acb49f50b4999c47.hip | 84 + ...511de2592b6e350737e44865e1fed6496e3f32.hip | 79 + ...632f996eb63fbe4bc5748c5897b775087446a0.hip | 144 + ...6662cf1c9900a4334d2cadcc5f5ac3ad355f05.hip | 144 + ...73457ac3be01cc1595a015a5f598f8290c77e4.hip | 84 + ...a07ecf1a59f72ec6bef3e970d7f33cf54c5f44.hip | 144 + ...c142d869ef940ca876c93033ad53b576ed34f2.hip | 144 + ...047ea90076e3b0a3eb0586d49b9ee74ca6d279.hip | 79 + ...0861e81e5acc523fa680534eed757b7b4a4e1d.hip | 144 + ...2f61bf31dbb5de5d7039d5ff2338068a759b68.hip | 84 + ...3132e712eba8972ba444c604f89e01c5b84cc0.hip | 79 + ...5bf652702c2976551778b9159e09188575c63c.hip | 144 + ...6b3eef02b904304348b9d35f715b639d63218f.hip | 144 + ...8e4c1ca112afec494fbe47a85b553302c43395.hip | 144 + ...914c00690ac5c4f89cdbbaf00732ba66c5c0ef.hip | 84 + ...c9b46da8774462de8c24e14b12df3ed596eb57.hip | 71 + ...2013527a0266ad479715ee3e6ae01c45de29d0.hip | 144 + ...410fd9a4150c33186a2a365d06d8f6ea621c20.hip | 84 + ...5d90000b55ab8b6055b1934880fc6c4870b34b.hip | 144 + ...643917fc970c043d1c80d8d4b17ec92deeb8a1.hip | 144 + ...9668a3212cd00edaae871758be30a5a1fea589.hip | 144 + ...9e6b93baae25dff97a0bc9145a8d328ed3f317.hip | 144 + ...43da478310245e19e6c6a0d9ed7ad99540b3bc.hip | 144 + ...6ef175029a43e64164176d4eb212baf9d27bb9.hip | 144 + ...8d747083272ea657604ac84867ecea17bd65da.hip | 144 + ...938733446b6c0dcd159719f08d04a9aa467967.hip | 144 + ...b3225da1e1842f83592971a1f62a0fe30aa9d3.hip | 84 + ...60282ad39ef034fecbdb74acedfb48620b7dfd.hip | 84 + ...835ba70606c769e56d19dbfe74061361aa855e.hip | 144 + ...95783ae8f0034692efd6563f789ef03fd0f4f3.hip | 84 + ...d77b228420a3ead919474ec9c6fb2800f86890.hip | 144 + ...ea90eb5a527434c1740933a1d2dd863eccf14c.hip | 144 + ...f90358e522d7bb7c76c3a2c6010f0f38788bb6.hip | 84 + ...03018e71d57d3266fc35e2e18a78faa3dd52ce.hip | 144 + ...8639d44a4a8372a627a7c31e9527c8faa26f97.hip | 71 + ...c2000d32c230a57a6712f27bc0fba02722f5fd.hip | 144 + ...0bfced8745fbd9266207463fb41476dc23afff.hip | 144 + ...1d897ad17d7f6db2741b396e6b85a9b8f35286.hip | 84 + ...5e61dad8f63fb973cb2eb899c959e400622652.hip | 144 + 
...8458c5a0720ef152848713119ebce6d76db6d6.hip | 144 + ...9071756e7d0582eb61ce6483fa3c988d2e10b5.hip | 144 + ...e4d2c757e4b8c366a2c320360e21ff0ef671a8.hip | 144 + ...f1ef32c4384ec26f3dc5e3af6a74fc8cebae92.hip | 144 + ...f2e2b108a53308a0cb6c123c8d318cbc2eadb4.hip | 144 + ...f7634d29bef11fd466b452a46b0612f38c949b.hip | 144 + ...0c484c2a366258941ee0051e139ea716a9de2f.hip | 144 + ...1a8bdf9d63b112e7fe5fa7e8835a6789cb8ecf.hip | 84 + ...2454f2d82184ab0491ea0675750c6ec55d659c.hip | 144 + ...2b4f995d622826af5d1f2bffa7ba68467c841a.hip | 71 + ...5a523f815eb822d66162d4feb75fe0bc50b648.hip | 71 + ...6c5836ba118969c4ba89ed62a98dffe3105738.hip | 79 + ...95d39cd62f20622a31f11a292ed175abb5fdf9.hip | 79 + ...bffc159b0bb826ba489ae763dae141bfe8e802.hip | 144 + ...c9e5384809b21f39e78bb2e43af345a9a21d19.hip | 144 + ...fe68ba10b3480dddc9866c51ca8b5efe962cc3.hip | 84 + ...3a980a26682d879c3a3425f3ba5be3f5761adf.hip | 79 + ...45129fc4995abcb8f880692f11c6186fc01641.hip | 144 + ...833fc01e88bd8e256ef64ae8251dd0ed10720b.hip | 84 + ...97c457144cb63a9c6c3d6be613b47bd0df9928.hip | 144 + ...d492377add5c8f6d0d2dbf9ee9e4338bbd9f1f.hip | 144 + ...e344010d49f7f9a6caab2cb84be7f87d2d96bf.hip | 144 + ...f6c5be53732eb1939a2f93232af7dc011dec1a.hip | 144 + ...0bcb241e5a1be1d35366461408d06e095a26ef.hip | 84 + ...3326e055da32cc979892a2fbd0f7b003cb9f98.hip | 79 + ...3af90387f1d227119c5dcd4b71362940bbce52.hip | 144 + ...4050988e5790a28dbe10b4c20e14f10f6cf85c.hip | 144 + ...49a9b0801a06dd89c7f7182d7590b515df1592.hip | 84 + ...50073f6dfeb7ea77d5dce288a1d2f08f8f6362.hip | 84 + ...5317b6cde327a842170ebff20c2b03d81379ff.hip | 84 + ...8169ce4b4b9a17ac96fbb232e6a93f22071ab4.hip | 144 + ...823c3b99e7c8d1cdc39a5dbc7365a383bf9ccb.hip | 144 + ...a934408c75da5479cc41f96b98ea7d333635ea.hip | 84 + ...b6da1095bd8669c0e48b5cd808cf0dcefa2674.hip | 71 + ...0bda0feaade2b554d648d72f219ac9c389bf09.hip | 144 + ...2e75e6f659a500dd3cf2cfd65118f111342119.hip | 144 + ...77bd7e89ed832cc31b2995566a49bec6e4cb52.hip | 144 + ...7aede7762a524a7a424cc4dc46e43fdedf73a2.hip | 84 + ...808da5c2514806c2953bb77d5692e5d7c97aa3.hip | 144 + ...82e3c4e445e1e02f14435e4ca01a90850139a4.hip | 144 + ...9756060ac0e73dbcfc58a9222a78f0283cd029.hip | 144 + ...aba3ab83239e474412fcf89fe0fbef97e51bf1.hip | 144 + ...f351fc2c2da4a8e1760a3affc9a5947c6b3bda.hip | 84 + ...06f77a4054ca615d96636c0e2eba2a89850142.hip | 71 + ...1f2d1e57095f756ddd11e8e9d4f6f253e3ffa3.hip | 144 + ...23a26e0a59a8323dd97632e610d24624143fbe.hip | 84 + ...43460c011b8d5e01ea98c9b8ddce962de59a96.hip | 79 + ...446754d7000673779d15d3e73039fd3c10a720.hip | 84 + ...7b637e0313cb423b22cd8844cc2997b3ff73e4.hip | 144 + ...9a04b7f41dd6f0db017157a44790f35c626e2d.hip | 84 + ...9c659ba43bb907fd4e3e36a50958288bafd1a3.hip | 144 + ...a2b905c4ce32234c2af62328adae6b1f9217a8.hip | 144 + ...b33b5442d2e0948762b1f2147a321a9d6907be.hip | 144 + ...fac5a83def98340c8786d55a30a98ad68b9eed.hip | 144 + ...30f50071113dc4ab59468d568ac9deb06b0342.hip | 144 + ...43e401abbfb1b6737e4dc822f68421abbc648a.hip | 144 + ...8b4260626beeac76c26dbcee3cba1457b30e99.hip | 144 + ...a394a09c8691a534ad2219bedf73724b6dd5ce.hip | 144 + ...ba937ff6d0302ab013db7349d4feb914107f1f.hip | 144 + ...0247e301a7b076b6ec8a778c3b47e330638963.hip | 144 + ...32f2d658f1f69840fbad511ce8a3851c859d52.hip | 84 + ...55a23a0f24ff7062a4c286944f25d2db3e20a4.hip | 144 + ...024440e780fdf9ec94deccc85216d8bbb5788a.hip | 144 + ...3b7b04496e4db7c1ba2436485dc7c8a4c88448.hip | 144 + ...76a6de0e2612279e0ed64612f7393856bcc9ac.hip | 144 + ...c8e4d5c761fda50e010da779e8e4730051d403.hip | 144 + 
...f0200092b0e18d57a9f5e512d565f1c0229436.hip | 144 + ...08502fd29d3a24b32177bcea968121ee809115.hip | 84 + ...10540b50e95e99a5cccebe47d9d3a83093c2fb.hip | 84 + ...1104394c8bef8d4ecff35c1409221e723a5a8a.hip | 84 + ...1731442b756308c0a869f21b7b8b103aa613e8.hip | 144 + ...222e158484773d2257f4a31e3dfbdb68336a8e.hip | 144 + ...63272d25bc2db2ffaa1fea87648b45ee68d408.hip | 144 + ...9df310195191895005b30151da8c1afab6c82f.hip | 144 + ...a968898f0bc6366313e41eddb5e3a3ed12dc98.hip | 84 + ...b807c48c472e9b1311a6037cd98e21d6706889.hip | 144 + ...c3760f5978baf9780ce4587ae4c768af0e49d1.hip | 144 + ...c4b866692ba5c3d115482bef4790733863c1fc.hip | 144 + ...06cc121ce8955ed59ea3b12b858ee2e0cf82f8.hip | 144 + ...0a6196b662a1d3dc7441a9536d825dc356b95d.hip | 144 + ...1500dd4c41e4d68834814a48a639f5ca36a2fb.hip | 144 + ...2a86568f89a5a5a165cfffbae9ca6949f2477e.hip | 144 + ...438250078ba2a47345ec4955dafb4e4de78a25.hip | 144 + ...527660fa7aeb9a951a9f2fc3c53989bd141c48.hip | 84 + ...5fbcb9e503e68fafea08abf86a4951f440850f.hip | 144 + ...652a27e8605cef59c8341813b68e7513be23c5.hip | 84 + ...7e27892bc57f3dec0da24f94f2a483d6c9321b.hip | 79 + ...8a311bafd1c153525393b252e4170f8aafb370.hip | 144 + ...099fcfc218ffdf69edb4f2f0e46121bea9fafc.hip | 144 + ...746071156e9ad46f403a539dc237e0a44122a7.hip | 71 + ...e7c1e5f41a451c7baff54f7238b220f1bdf8a1.hip | 144 + ...00f0af03743dce328486f8fc805dd30bd6da31.hip | 144 + ...08103188e27b3bc55dce0c1716c0b4d32d6494.hip | 144 + ...2d29c85070f488a14b1915f948e5fd69019c99.hip | 144 + ...4932e2655d7b32704be8de9a63bbd8c3369f02.hip | 79 + ...5a939a2491166dc520e9a2b9de7e43671e0c2b.hip | 71 + ...5ea796c8d97bfe3b7c9663bf15e2e5e7696235.hip | 144 + ...807a8e90bf1cd839f32fd718afa6469c35a4fa.hip | 144 + ...9241529745bf138552f49d9a93db418663ad65.hip | 144 + ...c2db98d8e2e690f499f41cfd5afb831b756f54.hip | 84 + ...11c54e6a6f9eec378d8b661121066536195d3a.hip | 144 + ...1425a006aeeff4d69c8570cb6bf1e1427d2c21.hip | 84 + ...4121d3bad1d448bd413718fa096f54faa12e95.hip | 144 + ...6f83cb96d0313abcdb24955edd4264df72aed7.hip | 84 + ...7f7e626135cc9176a295f3d1f336a7c3852688.hip | 144 + ...8399e756ed5026baf3ab78af17489dc07b9532.hip | 144 + ...8d28c958c0a831a615a4811d13279b18db09c4.hip | 144 + ...42b78913a853a62dbff8b99d9ae3fa458f461d.hip | 144 + ...6662dccf2f650bcd8123c49006c759cd4c0ef6.hip | 84 + ...7e58867c46d96c9bbaa96eaaa9f93595c9e099.hip | 84 + ...a0a960541bd8a2dc6741579de685b7c0a5f6d7.hip | 144 + ...7b70f54cb2778b5ce3df936b477f775eea8b3c.hip | 79 + ...8759ae25465c32960487375828e23c5f1ac869.hip | 144 + ...8bf438642e5d863e31145ada2a0688059aa5d9.hip | 144 + ...ad61bf8427a26775969f8a9166fd0bfb7446b4.hip | 144 + ...fe04467e87ec2110f60c7aea0cc9bf2ca07481.hip | 144 + ...010c9bf7341588f071f889b7a0b4dcc4e7a14c.hip | 144 + ...1b29d9888365bff0f109d897b508eebfd8a61f.hip | 144 + ...24e97d5ecba46e06d5ec1a9456c810d80227a3.hip | 144 + ...273a2f8e6bbb42ba0b0871b6c95abb34531f33.hip | 79 + ...a5ff72f22e0ad040a281e66b1aca0bf3a2aadb.hip | 144 + ...abcbeaa4d33d3150f2b0238bb62ebbfe960980.hip | 144 + ...b94d76503e13c911781169fbc378517332c42e.hip | 144 + ...bb367362fe2c4849ded728ec5dd00969ce188f.hip | 144 + ...e12dad9e3bafe177ed3c27c833825813e18fc3.hip | 144 + ...f8a89468cf9c8606cf12a930db062a83cd0ea0.hip | 71 + ...37d9dfb68351de2942e32f35e2ca1ce71edfa8.hip | 144 + ...422621a00ff79b2f5ec0dafb957c77693537b3.hip | 84 + ...67a8807c9451b09227c0f685c18aafeb062fd2.hip | 144 + ...92d5df4ba2e999caf6889a852db4e1ba078e65.hip | 84 + ...d3071347a0c98f3221104036f477aa13bffa4d.hip | 144 + ...1dca5feb864e8981387c2d07e62acef1730aa8.hip | 144 + 
...2280997eb6f1d091094fc54cecf42b7c9c3a2d.hip | 144 + ...2643099365d0903c799585f41dc1a525ac9f9e.hip | 144 + ...6b9566559ed2b1c85f2bea1c55e72c41dc47bd.hip | 144 + ...f86f458fb4dfcceb7db3357fbae0dc15142a15.hip | 144 + ...fbb5ac9048a962a60f48886728220ae6c2aeaf.hip | 79 + ...26eafe76cca8e74e819220b6de1f4279d48e43.hip | 84 + ...4ecb47f9ebe8c2784976c3e9bbe4834b475cf1.hip | 144 + ...508b92f7e123b21658f6e17d624ffa87831fee.hip | 144 + ...5b3c218e4a7b459e54080e24c5b730221eac02.hip | 144 + ...b129e6dee6848043dd0e8fa812ae80fec4d014.hip | 144 + ...b3b682eab96e4e173affad75b9d8e73f1dd690.hip | 144 + ...e7cea6df8e6dd56194e1172f28943667f1c4ef.hip | 84 + ...ed3aaf24c73073c604a3b23bb4b0358b8e3490.hip | 84 + ...1454ffc1418dac641f63671e947d9f550b1f0c.hip | 144 + ...38bb80e9880335faaea81985ed5d0e713ecb08.hip | 144 + ...3b7e4b8c1efe59f79a15512716fce2282a79a7.hip | 71 + ...64c33870ebc329921cfa3867d58b1857421f65.hip | 144 + ...b0cee09d633b6f70febbba63a1e090522cfb4a.hip | 144 + ...ce3baac1e3ca03af0c3f4ee4d0158ad1031e9f.hip | 84 + ...cf0a9d5a5451da5dbf6075ccea45e4a140550a.hip | 144 + ...d7a9ca49c1149d46f6b05b0fefc41ecaeb6ea1.hip | 144 + ...f45927b6d931e31e2209685d787efa28eed8ba.hip | 144 + ...1cea88a2277b87d405025ba256272a1720f88d.hip | 144 + ...289100991d4c8c362f64c8f6c4ba395c2f3495.hip | 144 + ...3f3eb2f5eb1f3287879604892b1c230df85f1d.hip | 84 + ...45624dc6e33c477c73a155500b015b6c010de8.hip | 71 + ...55cb42b0096a8ae338ce100f86e378aa1a04c9.hip | 144 + ...a8c31f6d5bcaacfa4a21aed4d1d3caecb48922.hip | 144 + ...ba3cd44f78c950fe7ceaa5f0629dfc607b30f1.hip | 144 + ...ff884e176ec7cff86d17c6afe1ddaa4dd6007d.hip | 144 + ...143d88eaa0d9cfea856b2f3a57d1275a656627.hip | 144 + ...2557f206fd81d82a3b9d59113105040beb891f.hip | 144 + ...562e6c3af28b8478020ce3c3bf73c036001c93.hip | 144 + ...61b019e1398a6a3c36143fb84b5ff22c9f4508.hip | 144 + ...839660557dee9d5bcda9b56940ce23236c5f6d.hip | 144 + ...b2ea922daabbba131b90713e06d8caf5f30662.hip | 79 + ...cf565a5a1c4a09887c67ac3b9a019dca427ac0.hip | 84 + ...34433b784d1e405ade3378918641372a30bf6b.hip | 84 + ...5e01b4f2ca8ea10898c39d6570bd74e85f46ed.hip | 144 + ...7315955f555768f24585a50d75e216c40f062d.hip | 144 + ...ad30ff0739ab5dede67a96e859f8c474c245f8.hip | 71 + ...cc6893456a559c7d22714116022fc69b372266.hip | 84 + ...18b1fcee808b6cccd131418b6ae9e8bf900d8f.hip | 144 + ...18f690b6322588041bb467beabd8a7bc79a2e0.hip | 79 + ...357c5e9739eae136a7abf92bc38d3ac94753f8.hip | 144 + ...52ca6a3ec02f6559e4bbf1edde42ad2d127c26.hip | 144 + ...5e7efa263223148318ae96bd1929b382e994e1.hip | 144 + ...aa64439b80ff8dd12498b3e5f6b625da16e285.hip | 144 + ...db688a9189e1c47c300d474df946a248a63303.hip | 144 + ...18e3ab290263ed2576feaf22a1944bf2ddcb7a.hip | 144 + ...5b183c50dd2663dabe3eb8b780913b778c54ab.hip | 144 + ...60f6b6d0869740a5a411abd80108f729f810eb.hip | 144 + ...7b1cb14b67dc82f614831550f7deb0895bd7e4.hip | 144 + ...9461cdb5687ebbb7bf0be136071d70420c1619.hip | 84 + ...b68458076e6cb129d3ec793e95b91430a0c8a1.hip | 84 + ...db3f29d1940e59dadc357c040ea37a6ff208d9.hip | 144 + ...17a48a1677bd26cd48e512f1fc8830a8a551b8.hip | 84 + ...8ce4e14cf94b284ffa735fe03d923cc74c9fe0.hip | 84 + ...9b82a27571ac91e3631cbdb7e0a58155abf962.hip | 79 + ...e2326066c91452335eac05f25a6311376bd9e5.hip | 144 + ...06c6c37cf472ad262f53941611b5e60072bdf6.hip | 144 + ...47e039c003489dd528faf5d710e687321a3fd7.hip | 144 + ...56b3a2ff49f72b91a6b9c215df285f2798ad47.hip | 144 + ...77ac04be3a6cbdbfbe57612a469412812fb5b5.hip | 144 + ...8e3565f4c720e6c9691b0d33c1392936e2e7ae.hip | 144 + ...95d3c96b3f4556b9765fd0a3b5701b2fb10948.hip | 79 + 
...e7c78e8f65be35e2753a0ad5123118555c56b2.hip | 144 + ...f2156a04b18bab55af60e9357f28d8a4604e8e.hip | 144 + ...09f2a7deb027e864afdfc9975d3ab93c5dcc9a.hip | 144 + ...32c5214c4d40c54ca2d02f0d4785c6d6902370.hip | 144 + ...462715ed5f192532760d6f4c66ff9d4e20e254.hip | 144 + ...564dddf8b492d80be54854abb8d1d831e42679.hip | 144 + ...5cd8fa559588f4264ce6192f2de3e3065365ea.hip | 84 + ...5e28a8a51cd435130ded2abc9fc606e522c713.hip | 84 + ...62b192a64efb60d5484798526278ac7a0fb9fa.hip | 144 + ...66b6c6b2ec3acb40ac1cda432efa1e4e62d9d9.hip | 144 + ...690e48f30657b0fcfa26fb3b9af3ef76e792e3.hip | 144 + ...c181996532676f2140fd026707135144e9d37b.hip | 144 + ...cc95831c347212021c0bab7b43acd7daabce42.hip | 144 + ...d82b58fdc3e5b7a7c20490ce7f5acce4e6ec79.hip | 144 + ...1fbbdc2dcf2ec81efce34673ee6c425cc16ca2.hip | 144 + ...68af1b2f104664fd05d21ad789aed39ecfa42b.hip | 144 + ...7eaffbff3c58183a656687010daa2c16cfc26e.hip | 79 + ...8d708d13577f2b92e6d5adfe952a87e0cf7be5.hip | 144 + ...9c8fb6028991321b09a990c2188d854d940268.hip | 144 + ...9ea3713aef9b916e1b38a882a45012930924d3.hip | 144 + ...b9871c220c0065d74bffeed4021d0304a9625c.hip | 84 + ...f4363f50af1e7ccd24751d5f5b181bf32c604f.hip | 144 + ...01680af41c8738089ff377147e0547dcad114d.hip | 144 + ...1737a13e24009bf1a5a4b780175043a9f2e33e.hip | 144 + ...66db0ff7b035e54f2c0e59acedc2131b722a55.hip | 144 + ...8a5f057fd5cef2df5f919f5102f47e86901e3b.hip | 144 + ...4fe2d739eca8c93fdcb2c105d4154cee6ca1c1.hip | 144 + ...548aa042c69bb9c59a8bf706b44028aaa41830.hip | 144 + ...f3ced9b5ddb0dfee8ed5e7df8eca0bbe273047.hip | 84 + ...fe73f04cef91cd2a0682e905483968ff80eadb.hip | 144 + ...1415463f0316ebe25ff2fda47c68cc54db3359.hip | 84 + ...24e1f8cda50f80988857611da766685da94494.hip | 144 + ...280c91d7cd8712fd533e246a6b0f758834abc9.hip | 144 + ...2e34930d11ff493007b1613993e01acc1af78d.hip | 144 + ...300e0aeabe337785d4c7b41796ce65df6cc42a.hip | 84 + ...3eaea4096c8f5bee16a64860432f0634a253d8.hip | 84 + ...435e5dd23e49e19dd313f9891ffec800ce74c2.hip | 84 + ...6f6c7c7655c34b7b9973ff357b0813f0a3fd7c.hip | 144 + ...7724686efd35731e5335efa949486c93ae26e3.hip | 144 + ...9e7be0f85656d012a6451b65f6c1d2613b187d.hip | 84 + ...ae3af78583258c4b13c11a442022e0e058bb85.hip | 144 + ...d7d145f96aa8958a9208d0c8887742a8c834fd.hip | 84 + ...e9e858abf6f77489f3fadc4ee81edacd26705a.hip | 144 + ...04c5910a2d0595b39a3f87652a9d1ef4fcbe80.hip | 144 + ...0a68220a7b621ae9817d7b77f55de239b0a4f3.hip | 79 + ...11bdd71351610d55916d452495e599960d0a41.hip | 84 + ...2fbc418e829f89bcb8d93f8afd2869dd8dfccc.hip | 79 + ...d4c005d723cdab9fbc307933c1257d114b539e.hip | 144 + ...f5017cc0f5c8c8dc71492e7765cf729c1f225c.hip | 144 + ...06b5b153ea6e8b1e20d9aad9d4633333fd98f5.hip | 144 + ...2e6b05e7e4de2cb23d815f8b2c8adf22131c0c.hip | 144 + ...4a00bd6ea27ff20a2903d619e1361b5e27672a.hip | 84 + ...5dbf601de5754c03a03a1a42395dc0766fb8ac.hip | 144 + ...9f3da698a6103caf25d785928dd9f814ac27b4.hip | 144 + ...b5d6e8fbfd92e9f7e47bda5cfbb0d4162a6319.hip | 84 + ...fd02981f92fbef6277c1985cc479c12bae9239.hip | 144 + ...1eaca3c37a82d19f8dc91f06764170069ca3af.hip | 144 + ...2e7f96b095ebfb66ecc7a75752fba2a63e4f37.hip | 144 + ...30f472f00bec9da0564ddc40e07112b5f9a117.hip | 144 + ...45948f2795293e72530b02669c4f549608ea7f.hip | 144 + ...4c03c916393d6be7c5181369ebcef949eaa763.hip | 144 + ...68e4d00295b294320b94bc777d7d34609127e0.hip | 71 + ...7393d55600c9892558248f4131fc06a6cf3309.hip | 84 + ...74439f42140cdda9bb0f78d995d741212a35f4.hip | 144 + ...76e5dce9af523422782dd25d8dcf6f25edc68f.hip | 84 + ...af664bfdf070362bcc91af77d1bc406f744351.hip | 144 + 
...c48576f285325345fa1205e5e7e01787b74f71.hip | 144 + ...d4d46397a3749646b232b306688e52b8c6e584.hip | 144 + ...e4a98f150f3f9ab6f03b5fd0968c5454565c9a.hip | 144 + ...eca56234ff6fb4f23b9b24822887fd9a3d0df9.hip | 84 + ...ef4d120e71bfcfe61d67aa44d24ceb907c2b9e.hip | 71 + ...0c50a1fac82d47dff2357ee3ddbfa0b2c8d487.hip | 79 + ...69d06e3f32e3b6d28d3e54ad764b472741c193.hip | 144 + ...8720923c3452e3aebd7b9c1b4b23f0c35d7e4f.hip | 84 + ...abdafad0bf803223ba5e8f474cd59233dc48cb.hip | 79 + ...b1861e31df98bdfd731efc3d335055090d83af.hip | 144 + ...d3de43cc1f7588d62a10362f59d113ee818846.hip | 84 + ...e03571f1d2779bdeaf0a6a2d617e236d191c11.hip | 144 + ...e671f5defd76ca08614a7a1f184c36c0f1e2ab.hip | 84 + ...3b1ae63e127b6e6afe39e354d4995afc5faeaf.hip | 144 + ...5f3cf0f78f73df79665c26b20b0805615e1b04.hip | 144 + ...65e58c9f147498ed04dd51fe1393770603a6d3.hip | 144 + ...7dc0f356b630179916f8fc2041b7f1402b46df.hip | 144 + ...a9e9b7277bc90518ab92860bef2097ba96d982.hip | 84 + ...b2e63cfebcf84043f79be0321708cd159c62b9.hip | 144 + ...bdd9c3f496a27bde68cf86374999ff2dd53505.hip | 84 + ...c87b7d385e7b092e4706c464217b004fd8a6a4.hip | 144 + ...de56efe17f4fd36a11cc959320a5e43f1dc232.hip | 144 + ...0a88ccef04e81b8c684b695f7cb4310e448915.hip | 144 + ...15e4f16de26068cba30ef12fc29332d45e460e.hip | 144 + ...47f8fa40332c6ed12d9971e0b539049a871c34.hip | 84 + ...760de14b71a41882ec4a2c7362565af36d1a5d.hip | 144 + ...79dce18e49ffe024fe4cd0693ad3399f5edaee.hip | 144 + ...9a933b916285d9580a76df543cfafc88a536cb.hip | 144 + ...c2075f394acfb14fae7b1ef4304fd9b654ba0d.hip | 144 + ...d6da5357b67cc28aee4afa9523adaf055c4e32.hip | 144 + ...f35d82ceb4af2e07719c16109c6d72eaedce67.hip | 144 + ...0aded9d1baec3125ce8e176248cb146ca580fa.hip | 144 + ...1e1c969b57659e7e1367ac9ba10ed5ef5b69a9.hip | 79 + ...44435491aa68acb3217b0e693232c67641a2db.hip | 144 + ...4a5d56721bb1a1332a65882132a8c5763932ec.hip | 144 + ...6243c6850c0a2d2b7bf1476e12f95f187257b6.hip | 144 + ...a4d21931b9afcbd70b1567995d3eeb6f9308aa.hip | 84 + ...a883a36a76edb276a66c5d779294f170d6d4b7.hip | 84 + ...d34faa8b168e2ac7862641229e6146d3e28aee.hip | 144 + ...e530cbf6363a8f08a94728e45e88ecde299e7b.hip | 144 + ...f20bafbf156fe8fb80bdd84a5d2f3a4a944c1a.hip | 144 + ...1dcf3213efd214cc2ce8c9ba0027f991d241b4.hip | 84 + ...52b2318dbb78b1a82ef03666a35a623f44481b.hip | 144 + ...93976cb7b32a8bd28ce92fc13af00a3e21f737.hip | 84 + ...e59bd079f4d205b613056f975fd2b4e372ab10.hip | 84 + ...e7b11019fc2299d70869253877319b03388244.hip | 84 + ...f887556a3540609649744957651ca667b91774.hip | 71 + ...f915b4d9bd18a3c25a85917392ea4a5e88b349.hip | 144 + ...5128c6978449b33ce0c35b02a9e9aaad65ef7a.hip | 144 + ...2a2a9435103ed405dc1500d31652f1d431a49d.hip | 84 + ...3e5bf45ec5008aa3aba4773e68a78e122b2fe7.hip | 84 + ...688999141a72e61322140db29043ef9f7fbc3d.hip | 84 + ...6c89b7a04758b4badbf9695b316f877b8bb053.hip | 144 + ...8db08068589c6e4c096054d26a2e5be63285b6.hip | 84 + ...a89981a05963efcea7ba5c1e967638beeebbbb.hip | 84 + ...a8a323414448c50571a334f29bc0a38919b61d.hip | 84 + ...2a6ffd8a21d3e98342fd401f0247f62ca4e038.hip | 84 + ...44427df3ae9392c4fc4c25c232196828e70648.hip | 144 + ...82a30dcf702daae19bd6705864bfe36e09502c.hip | 84 + ...bd60bd2afee49b30a583c32a45ae9f2076db08.hip | 71 + ...03eec1cdd216d5c4a7ba977e2ef92a0d7fcc8b.hip | 144 + ...0bd57333c6839ccf5cf2e928edb996bc60c371.hip | 144 + ...1874a7633e5713720b9d084b6d1c6715a51a17.hip | 84 + ...208a6e8c5263e38f9ffcb062564ab61d2785ff.hip | 84 + ...35b4651a90e331fcdcf224282457e3dc038a30.hip | 84 + ...402a22ceee3b665a3f24edb98b8398c35c6f5a.hip | 144 + 
...548ad36fb92d0963893146c8db20f53cbf0c8f.hip | 144 + ...67aea26852aa9a9e3dae76b906005ddf6fbae1.hip | 144 + ...8b347672451e8391388a400d016803f4c4cf8d.hip | 144 + ...940ce53998becf9bddf56df7d19894a7658168.hip | 84 + ...9b6956eaf678f7eb901567d1a515eddbedae5f.hip | 144 + ...b6e18b10d529eb6b32d7c19c59eaefc7184376.hip | 144 + ...ff49018f1c12b9fa31e523ad40b9cc162ba34d.hip | 84 + ...5ba79201a585bc091ccfc326fd24e851d1eecc.hip | 144 + ...6cd05288e1666f5c67fb87ad02ce660e4c589c.hip | 144 + ...b14cf2998a61611d1de2594e926fcdc378999c.hip | 144 + ...bd9c4f1b7a0621c67f3e964d946ce22fb2fc80.hip | 144 + ...bf8444c1c26b91fd490c7216f4d0f8aa0a1f1a.hip | 144 + ...cda610c235987e13232e828f8d86fa88030560.hip | 84 + ...ea83a47c6299fefa4220ed88f7a8e1dd938215.hip | 144 + ...6b4782793c6526bfce7362efbf6bf069928b2b.hip | 84 + ...6e26d4969bc6bbe9b092bedab11cddb3360c0f.hip | 144 + ...964a17f902257aca9d08c736516a2c67d9a0e9.hip | 144 + ...cc4399c5567a9495f17d54c712cc9e65e57521.hip | 144 + ...de9a7dfb1201b56528740e9d8a07b62710fcaf.hip | 144 + ...ffe9e21362afe9c3a407c09d5de186954931a6.hip | 84 + ...24d91c1fd6290a6cf8d52a3801ac6b921dc7d4.hip | 144 + ...2e68bd619e118292768f0925ccf92cbfa68415.hip | 144 + ...32094f5917e9164ee0f973ac6ec47245a69101.hip | 84 + ...89f267d34c9961ced63ad07ffea2c6d2911415.hip | 144 + ...54f09511778dd1779a839b0b194896070f69ad.hip | 144 + ...679919fcd292a2a69543de0db94e2985c9d364.hip | 84 + ...762476c7f2bb05dce92ec22c0acbeb03676746.hip | 144 + ...7fc33d02b1932235b8d152e57559060211d591.hip | 144 + ...a784fb478ff5b3f1e2da9765a3a777efda92e3.hip | 84 + ...a7ab44bbd9fbc97c7805860d5f6ac81d6ae468.hip | 71 + ...eb2edc7738d8d18ac359691da261ceaaf71788.hip | 144 + ...19133d2ed892745013b2fc5d503414cf0a4d83.hip | 14399 ++++++++++++++++ ...39e6610e41aff8d1ccdb66d9e84d3e48e8d379.hip | 144 + ...4929c433b049a8cf949ff476309a8faf5c25fb.hip | 144 + ...7a0276ec419f18f060a5186e6bb703ae434ac8.hip | 144 + ...901147b7188212b8d8feea15831a11425fe4b3.hip | 144 + ...beb9cb4e161f9dcff79080149076488d436301.hip | 144 + ...d366421e0b51c90fa53c366d47ed8d51b3a329.hip | 144 + ...05b4e7782bd0e29ca9f6d33fc59d4304136d41.hip | 144 + ...216f777feec4752f5882677b18168225da4b53.hip | 79 + ...29b93cee012c79d4364502f1d90f947c73641d.hip | 144 + ...85ae0a16e4b293b549bcb6a3ee52df7fccca32.hip | 79 + ...ba1183efe205af38e79a1b2dccea5fa515d02e.hip | 144 + ...ce1c9b00f160a17355d4583d49c47887ac33c8.hip | 144 + ...f96b404feac271dac8f4190180754480d3ba80.hip | 144 + ...413bdc825ae863d53dab548f2145dc0de8fd37.hip | 71 + ...55946ff3c15a44b9c741e9f6bbbcb5bd4c8577.hip | 144 + ...7a4ea3bb8905a22ae97a94c354b1cbe38093bb.hip | 144 + ...a578c0e7abf1127dd0370f06d7278656c93ab9.hip | 71 + ...c803342862aa30e23e5be7d84e611bc571c529.hip | 84 + ...e9ed84ad9be1627db7a66af9370679816c0897.hip | 144 + ...ead6be6e39ece0e5d44335083336f7f546d2f8.hip | 144 + ...36fc744dfb0d985c9113175e76c7ec1c935054.hip | 144 + ...742b9ac6749f189d597ac97d46d35189472c50.hip | 144 + ...d03e29403ad53d6d52e5e81182ea6ff5aff2be.hip | 84 + ...d41b6f578f3c903eb9d58ebfab62eb296044e0.hip | 144 + ...707d065ae152450f9def619ddc3dddb9089e88.hip | 144 + ...7ed4c885fb32a0b548186e56d64bab98071d30.hip | 144 + ...aedab8931f2eefb649b91e80145cb71b63360c.hip | 144 + ...e27c4081377f59363c2bf2ea8624217566d2d3.hip | 144 + ...0abf4e2b6be3e2c555c2134705b9dcaee617ce.hip | 144 + ...62968de58d9df7d687d671f37d63393f189321.hip | 144 + ...735b12d130ebf849ac5d6752e413ecf3e69fbf.hip | 144 + ...840be0741afa4d41fd4789c8300223fdc63ddc.hip | 71 + ...a53f7c6370845fa94aa9b395c52fd1900b62de.hip | 144 + 
...fe77ca5c394a60af0313072cdd132216a52bf3.hip | 84 + ...20263fd84776f155519b3481be5e2c5b035585.hip | 79 + ...3c3bed2b584ea2031debf9f953f5f8f7012171.hip | 144 + ...71e663978dbcba859c5114ec675a712e343fd6.hip | 79 + ...8925f929a5b26f3544ca31938aa75b3c59d34d.hip | 144 + ...954a393b7b5a7131c13d0c4578443f468a738d.hip | 144 + ...a19223cf296d7fd10e15e2571e63c84a80fbb1.hip | 144 + ...a7fafd4227918e0c7f0c6ca3b2bd673cd07279.hip | 84 + ...b062527121e627871b3f1b2a94b96c42e51205.hip | 84 + ...c66c5b53f83bf1e023e81e9d51f0285b3ae731.hip | 144 + ...18ab272d7306689c7dc5a6d5326efea1471235.hip | 144 + ...49c01db99fce654e9351e711b113cf7424550a.hip | 144 + ...6f5e0b99814b0a82a731de36f28024bc317801.hip | 84 + ...801d21c14796c08377349ec86a6c800af497b7.hip | 84 + ...82d55544b5280b49b071ea277fb1827193fa2a.hip | 144 + ...9616f72bf16a060fa50091ac139ddc06bf9d88.hip | 144 + ...9f68180582384ba81aae2b1d4a4c52dde2c68c.hip | 79 + ...efa9c427dc278c0d1bc31189f683cd45e4d873.hip | 84 + ...204f6805d5d830aa6fca2a9b5f238ed63c3a73.hip | 84 + ...220f6dca850a5b5ccf1f619a267c40c37efeca.hip | 71 + ...4a9f10ebc51bde3f580ef527c17f89489c12c7.hip | 144 + ...5430cb65d8d540836c7f12b3367abd3c8e63d2.hip | 144 + ...8031345ea71cc17e458eb97a559b7c94d3ae43.hip | 144 + ...896aa9e4e4d7e494c1755b1e77a08e0e264f8d.hip | 144 + ...a44ac409e914c12281f1d26e5b52d8bfd0df75.hip | 144 + ...a9e92183ba87924e73ff0b5e25bd12d6038e69.hip | 84 + ...048a8ae1c0096f3372b0114c15edbe813425fd.hip | 144 + ...14f820b39a8ba81e547a78ed19a909ac13221c.hip | 144 + ...1da34ee666903307d3a09b7a032f2a70054759.hip | 144 + ...8b28f65f19e7d1b22fb3b85b7cf3d09cd54ebc.hip | 144 + ...9e0b97b3fece7c12504f4c8f1860d611b57269.hip | 144 + ...ab710e4acc711430745e05e036dd6a4d6bcdca.hip | 144 + ...ba7a5a0f3a714eb5f9f2af20f7bfbc82a30350.hip | 144 + ...eb2f81e73d65fddce7ff43c397da6529317607.hip | 144 + ...4d530731c7ade2c7beecfd1bbbca8583032217.hip | 144 + ...60621af3f7e1e81a8be48fea8d2750fdecbbf4.hip | 144 + ...76eb68c550b50b9aea42a7a2cc3bda186b0e40.hip | 144 + ...c411351ec59bdbed2590c599f9eddf7807b371.hip | 84 + ...f121a3c8928c10a2d86b487cd13fa995da670d.hip | 144 + ...3b3798f11997d33ccb58d90ed6c10d5411b735.hip | 144 + ...9336d59a8b35919e593217b6fd4314a04ea359.hip | 144 + ...a0ca185449a49fa485892fde6af745ba758167.hip | 144 + ...b3488ddf3bb1a4870371882f0a5d267bdfdf73.hip | 144 + ...c3c1e3dac623f07c2dc1b934ccb868cafcb38c.hip | 144 + ...cf03c0aa3f1b2a7b76b4e3418eb5063b982a29.hip | 144 + ...fe2db75cb20428856b02cd1cc8d7b393a6ad9c.hip | 144 + ...794d9c185b21f59274ac5d4db10a7abc0be968.hip | 144 + ...8552954505a2092662071401e135e84956c4c0.hip | 71 + ...910c8b7a30acc731948ab58467fdbe4fe32f6d.hip | 84 + ...1b49505cfecbe4ec3e5c7371de3aaaa85ac9d5.hip | 84 + ...1ffaf653085dd7f122d603bb3ba4b001e5f3c0.hip | 144 + ...2767e588220d0dc6137b00cc1d8dcc91e97134.hip | 144 + ...49f19deeaea20663bee781af7edced7f7a4fc0.hip | 84 + ...968bbf7e210911fcb95ba90c79837230ab1ce3.hip | 144 + ...a020f728df204ff51e37d2ddc21afb0aad5e7b.hip | 84 + ...be70b088b20fc8de464167c35745461ddab640.hip | 144 + ...f651d3415562206c1049b172261fddba01ea6c.hip | 144 + ...1828f15eec2a58be23063a1a8132d337cd26de.hip | 79 + ...67cce35ab784aa42ebcb75af7305bc38a8721a.hip | 144 + ...85dcec0197fdbb50124ab06efa627f1a2c0567.hip | 84 + ...8a4a8210a972bb2ed89d6ac754fb79438ab2da.hip | 144 + ...fb736c61088b8dd92fe0371f5c98e23bf9077f.hip | 84 + ...0e81c3700f130df142c9a37a368944ca548721.hip | 144 + ...3e8a33fdb7053760c9c135002b0a94facbe015.hip | 144 + ...7f4aaafd1a5b9ee85aadc6fab79ad0c27a2ea2.hip | 144 + ...8aaa193f332ed13e017e78ec07a7c80e45f6c5.hip | 84 + 
...05ba47078abd7a5b6a51eb93b26095517e7f70.hip | 84 + ...214eb450c3b249017480efb8d092b0edad6dc3.hip | 144 + ...79ef43adffdb62100270a62706fb811963925a.hip | 144 + ...cbe8eca7e3510f5caa7f13419cfbefbf031754.hip | 144 + ...3f42d5c9ccdd3807e488b00f02bc6ab5d8d99a.hip | 144 + ...4b6226b355bf35d4d07aaef1828091f03ad2ec.hip | 84 + ...66604bb15f97a56847a7c968dbe32d247cbc13.hip | 144 + ...7b6781ffff9a42beebb4d73f0d15461ddd4479.hip | 144 + ...7eb3d86aa385f9ecffbc5ba10489e56856f918.hip | 144 + ...95543aeed81adfb6d847f78212585a36122ae3.hip | 144 + ...beb7b50ae6a1fc62535b9a1dabbde6f177a9d0.hip | 144 + ...f23d1460abfe875e71f7911697c42fef0f41c5.hip | 144 + ...f4c15a119e805e4407b184625f57966f8833d9.hip | 144 + ...0ef67ce0f178aa2863c4909f5bdd7f766c9b2f.hip | 84 + ...638314efcc4f16aa4a6e58e6caf2fda1711519.hip | 144 + ...ad2ed9f91bc1efd89ea66cd5c775fa140cf931.hip | 144 + ...fb7075345704340ff33dc0ef7c04ef127f26ad.hip | 79 + ...07bf9c05e41dcf2416e05dab4bdde17158db76.hip | 144 + ...17b92fab5bee7717bf9aff6a6bef7cee3816e7.hip | 144 + ...307974bdeeef95cca0d130ebb7aeb77fb1b6eb.hip | 144 + ...40d762ed576832b3a752453e9881b5fe6d2650.hip | 144 + ...470f5c6fb81032fcd7974180297d4bb2a8427d.hip | 144 + ...5aad18f59e47a3fa3278c7ef1a6372830c33d5.hip | 84 + ...b86621d626722434f2ae9b7b8ab435a8dd8827.hip | 144 + ...d707cf48a17d31abef94215c5720419faa0a39.hip | 144 + ...240106c771ebea461fc2a87b6da68e510aba70.hip | 84 + ...6a4475ea795935f4cbf2dc0ac156a33d754587.hip | 144 + ...7e1d245baabe2f6293e3d85318f9936b333500.hip | 144 + ...8cda718e10824956f0ee39bbb0891eafa45a7b.hip | 144 + ...ca9cd905ea8b0454cf9564643894682b08cb97.hip | 144 + ...ebd0c2fbfc85f938b10535855c388971129a28.hip | 71 + ...f5803b33d97db72eb8a8528aeb3fc956a938cc.hip | 144 + ...31b3345893eec8ed1ddf1d8de2512b46ff6187.hip | 144 + ...3d098f8bb63133924aab70d26a6ed64018c13b.hip | 144 + ...8788c537cbf6833c58a6ca15c0a36de33c9fbd.hip | 144 + ...88527a2cdb5adf51407f4661a254bb32d7de23.hip | 84 + ...a6478cc27e52fd9511fbff38369c921155cfb9.hip | 84 + ...f4605d82507fc4bd6e96095eaee5173ea41973.hip | 144 + ...f58a5186d69efd6062f3717bd315394ea6592b.hip | 144 + ...3246f1f53a988cf252eff88bdf814bd382d3ac.hip | 144 + ...586668a61ab88bc46b763df8f1c2ea52001ea0.hip | 144 + ...c8e45f6ea7cf5dba9eeadd0b19481d9f5defb7.hip | 144 + ...cf755f1485c065222be4daab84283a9c3d0eb7.hip | 144 + ...4c5369aa848021e020d874289e3ae4e0f74d77.hip | 144 + ...77f939ac3dae8749cbf4232dcf04d2cf63b48f.hip | 144 + ...a2d046629a4b65c90d0e18d061c4984062f844.hip | 144 + ...b6100efe30d836dab557ea4ac54c4b9d35c6aa.hip | 144 + ...dcbe9f481c92215f3b636bc0e86ce8f65e6472.hip | 144 + ...e3980331dc4bcec6ab6f4c345c7b5f71356979.hip | 144 + ...e5fb3544dafa9da03fd2de4bb9bd0718f6009f.hip | 71 + ...37ce5f3cf13ace3efc0b0227ae5a8c1fdfce1d.hip | 144 + ...4d1d4408196d611b2e0535bf8833652acbd6ef.hip | 79 + ...64e378e1ea1d4dd97f6949d66f3492883b663e.hip | 144 + ...abb25dba0c48b380b2dabeb6ab7efaa706d180.hip | 144 + ...09c38fc8a2d5ad6efd449107dc54a7509624fe.hip | 144 + ...44f96bed2f56793b1c2583485aa161cdf30379.hip | 144 + ...93267865f1c2b0aa1a09a586f54cec98eea4ae.hip | 71 + ...d4901b8ef034590314048de7223a572d61ee0f.hip | 144 + ...ec21ed6e040260c4f04ef68ef9307aa86985a7.hip | 144 + ...1401abfbbbdf0dd1d62df8bc3e85371ead71d6.hip | 144 + ...3176ecb1f0bc800c870861585edf56f88d7739.hip | 84 + ...4ec604c577a27e0aae5b39711a9e2eb82801b6.hip | 144 + ...5705ae121a1a331527cedfe4d31218a428a0df.hip | 84 + ...8a3d76e8ab73af9a5d2302d33e3b1d1b866dd1.hip | 144 + ...97eca4d1a18306b406b367653622a8d64095bf.hip | 144 + ...ba59d347ce8916a22b40e6f22a3c89e13db4d0.hip | 144 + 
...d5f2aef029f2103bb419cc982cae99fd1a9253.hip | 84 + ...24904ac5a2040c7ea72aef5942212f291a21bf.hip | 144 + ...8b211174da0f398b2a093e7389905b4f9c4060.hip | 144 + ...96c14b8fee751d03f42ca48ea4f66e87fc2e2f.hip | 84 + ...97ce4d2e5264bdeda47487d5bdb55a014c6616.hip | 144 + ...a310a6eb86e3e8baac7a930c3ffbef372942b3.hip | 144 + ...c38912947881caa14b3fc7ab7bca317e296dc3.hip | 144 + ...f2010bf6c478d2f0eba77e912697661306c1cb.hip | 79 + ...f21e38ad01fade35b1db40adabd75eb602410c.hip | 84 + ...01e6aea44b96e94fb019501be6b102c6e6a654.hip | 144 + ...1bde840c0c8149b24a8f6f264e963c4e9e8ceb.hip | 144 + ...5940baaaa2ae6ade43ef4c94a220eaa63702b0.hip | 144 + ...674fc182dfa6329c73a354aa3adf458429444a.hip | 144 + ...704ca28a4877a1e84022e022614709adabb280.hip | 144 + ...8c80fd3ea17813df1bf19a158186834fd00780.hip | 144 + ...be322fc072ca19baa82707e260c6eba936ae19.hip | 144 + ...f884e9ca116ee47b446efe9fc770c178a858d5.hip | 144 + ...0ad1eb1b30ad8f1e7c17df486093129b2d5630.hip | 84 + ...200e875e0ef160b311c7de450c137772312d0d.hip | 144 + ...2016803aa3ca6ebe785557118365f9be7c4339.hip | 84 + ...26be8909f631c04d4395fa4ffd03a736f447f1.hip | 144 + ...28d5bec7941c9b6d5632bee8d67ed92b9c03ec.hip | 144 + ...64814a0de7702f0b7b5ce9dede6440603f4853.hip | 144 + ...a814291d8f01870274149b9d82fb75921d6e20.hip | 144 + ...d0223697ed41c4c2fd8830f8df6e5620db547f.hip | 144 + ...31ce329f2a0812ebb1dd103ea4ba8cb7ba531d.hip | 144 + ...38849e57ee9cd292e588f587a8079b57becfc8.hip | 144 + ...3ec08544591a22f59dc12f169b7327b4185a1a.hip | 144 + ...4c35fee4d372123631312f1051c43e1fa12378.hip | 144 + ...663faeb0425f45e8a0da0f7b1a5ddbee5e07e7.hip | 144 + ...72c45ba170f2782c4b5b75cfc78ac79a4cf157.hip | 144 + ...78e2a4d3b96a552e03d1ffc33debfd50c9f7f1.hip | 144 + ...e1edca5abe1bb3e7aa946eab6484b7bed806a3.hip | 144 + ...e945db4afa1330fe3978bc1bc9ae99828ae287.hip | 144 + ...f7e2a2c08cd87702793f91b6935cbe4c22be55.hip | 144 + ...7750ac0b18b48f56ceb4640256e9bd3a36621a.hip | 84 + ...93fc08ac5c6ce7a2eceb1227f4e3718dc4cf5f.hip | 144 + ...a7dce707954e765d97cb22e57d9bd6168860d9.hip | 144 + ...d0b8053ddf99a4d4447656d733c2da026b3a7c.hip | 144 + ...f182ae021e23869d7bebf2a9b4575bdc910ed0.hip | 84 + ...0ab620e6d62259a559e329460e46e6e3f7c3f9.hip | 144 + ...13d62a715fd717f0d4101f787349cb49cbe70f.hip | 144 + ...242e5953f44316b6a4f6587ec26283ed6cbcae.hip | 144 + ...2e032f6500fbc5468183415b6dd1d3e43f0bee.hip | 144 + ...890b126da2d8cfbf84f048b779cac2dd56b509.hip | 84 + ...902ed4ae3cc6558c73b730ff3949778007a230.hip | 84 + ...a14aa94d625b33df1adfa30ef4d91769592608.hip | 79 + ...b03a62e064864e1e9c1cd506c1b2e1786a777c.hip | 144 + ...df69b51f0a8cc9ae7e250e60df38758230fe4f.hip | 144 + ...fd1a756247b15b078d15a39e350a07c22982da.hip | 144 + ...2d3680c3578c7292349b58843aef7a82e0087d.hip | 84 + ...5680f97836be4a369802e8115617a83875703e.hip | 144 + ...67045d438a7e4b8f3a313a5df5a85f351c1be5.hip | 79 + ...7fa76609243a8709f349ffc0d9d88157f28dc9.hip | 79 + ...9a3bf1a9b37e0bd9bae6249609e5994dc0dba1.hip | 84 + ...b7b63e8a4c1df4eac4d978e166867195bd6e53.hip | 144 + ...19fc90e5a9c422dbf529d2def286f47dea0f50.hip | 144 + ...23dde1a386436e9864c8fa5f1706c0d2fbfd0d.hip | 144 + ...3d8ef4da515960bf40eb1feb04d21950ad5ae5.hip | 144 + ...4710e8f4e27fae4ae079f1667c3a1879cb6da8.hip | 144 + ...be4562c51d6829ec5942e11035c452fe318b3a.hip | 144 + ...dc419d4248dfdeeab1f0980aec35fa134e52e0.hip | 144 + ...08373ace7087bdaca4ce8b0bc329f553f88d77.hip | 144 + ...0f767c17385eb7d756cbe8ed444d7cef72dea5.hip | 71 + ...12e9cb599d24631c082e3cf65d2c58b6d4d44f.hip | 144 + ...2f87c021e0b6a27b2d7e30351fd50f06414b5f.hip | 144 + 
...5667b27f15a06d4040354fba3601d48bb9c045.hip | 84 + ...ac5d4cf103d658e129673549549f1276f134e0.hip | 144 + ...d260849b86c46b685955cab54ba07d49b47954.hip | 144 + ...dd621da88c57798db1e689b93b692b6519ff96.hip | 144 + ...fe21ee27f8a0ca0407ef0dea73cd73ae6940db.hip | 144 + ...1bdde812c332c9fc58613698568a04771b9fa8.hip | 84 + ...332a6aeecfb12dcf70c69157fd3137343fb9f6.hip | 144 + ...6129eead18d13a4a6cb9550384fddabc7a2a16.hip | 144 + ...89f79217037e361bb0909d06534e40f5026b4f.hip | 144 + ...9519dd0d0f940fd5efd61bd32df7528ba7e3fc.hip | 144 + ...9c7feb747241c9c7de2adf3a19933a1c4c0995.hip | 144 + ...a9c37d92e344f3cc58cd4d1d00f19167e3623e.hip | 84 + ...c038393ec329a894aee9bbac078a40f57a4684.hip | 144 + ...c04763d635c5bc3e810737b5d948c59f117d5a.hip | 144 + ...e953cb24e28bcdc8f05783894b23cbf83bdf35.hip | 144 + ...6ccdb3c2d595fffd05bc5e6417b157276547fb.hip | 144 + ...80d44e82e601dc48d4c8b4e710ef7265894b6c.hip | 144 + ...9403cb91d6aabebf081afae94a8ba397d8d24f.hip | 79 + ...9bb3486fee7b7c9e24300b8a4e4ce88a11bfc0.hip | 84 + ...a76fc1b066a15b08dc6c24a7cf33a58b4cb6cb.hip | 84 + ...e409f4421193fb48a54aa5f26bd6229d23204c.hip | 84 + ...f65c7abd9b0d8a2df9302d6dc167637b3a72f0.hip | 84 + ...04763f674dfb3f14b66dfdeb2a046e413ce2cb.hip | 144 + ...07bf7ae1b71bf8ac4a793aa519ad333aa7a7ba.hip | 144 + ...21fa266c77e6b5bd1af2a9c22c686e5a6eac78.hip | 144 + ...2b21f9588d72c3c3e3b9a3b269f19c484d5aa4.hip | 144 + ...46f566fa7188c92568b277354e8b06ad382544.hip | 144 + ...6f9ab9baf631df1d3a8d801e4cf93a102526cf.hip | 144 + ...7545400aa6e70ff49a5f38ed6a218a180bd87f.hip | 144 + ...987e2d765efc320eaee813607c94c80ee35aa4.hip | 144 + ...a72d70d80b66c19e85daa00497308381050048.hip | 144 + ...bfb0e6032892cc58cef4dd403f305a5b76851b.hip | 144 + ...cf0997573f4bcfbaaf75e40f519580a7495a17.hip | 84 + ...efc341089a50ed5669b3c86f6ddd9b124d1442.hip | 144 + ...f51f0e178c33e6196df1d2e47bd38bf5391cc8.hip | 84 + ...fb694fce7b4c3c459fca43c89c6002fbfdaef5.hip | 144 + ...0dd4e870ceda3ba9b5f0084a4b025b2e609d57.hip | 144 + ...1db756577b61cde9fe8279d956980db9ee21a4.hip | 79 + ...3e60e8405aca3f7fbed19452ae37574ada9a77.hip | 84 + ...5918206483d2ae04a45aa67d69dfb986587214.hip | 144 + ...6c48e129a0235cb3a19124ddb28cce286fb368.hip | 84 + ...acf1d17650712b71a499bb66909bfcfcb6aecb.hip | 144 + ...bb8f13b6f20a72c9ce6d0b53f81eddbf05f1c6.hip | 144 + ...dd3ea61bb61de02667b14f5a94198f48c7307b.hip | 79 + ...f6c575c3fa2ccc7e65022f1ba65c8cfc16541e.hip | 84 + ...048cf91270631f98ac37dc488a1fb2e00ce004.hip | 144 + ...50f27341241086515d833aa53ae873d4ece3fa.hip | 71 + ...78845045d68027dcf3bf867ecde2fb12ec51d3.hip | 144 + ...ad0c0580516485ea432d98f53e73f6dfec548c.hip | 144 + ...c932e6eaaf44861c794539d9caf8b50192fc44.hip | 144 + ...d7f61e6313930f063758b61102e7a43b118beb.hip | 84 + ...f0f3d71108dcc49234a258f0f3b21ea2123cc0.hip | 144 + ...f1d7e1a93bf2fa80c409e6827ea88af56c44f0.hip | 144 + ...01bfc0394936a68fa0098580f06e77c88ebed9.hip | 144 + ...080406598df6bd3102db70a554e496e29db96a.hip | 144 + ...0e3532f27b391585d5de90f3bdf97992b67651.hip | 84 + ...52031044ef2e4a22e27ad04ab5d2c02121faee.hip | 144 + ...5a906031a258c6362313eec783678bd8125c91.hip | 144 + ...6a308c2d2afd6e0dfbfda61984b631c4ccffc6.hip | 144 + ...d580a612af85533c87aecdd7b0345c71b75980.hip | 144 + ...d920a76114c63156740ba5dd6f3846c4b21c28.hip | 144 + ...ddca2c6ecbba4314c434e7471ffb8fa642f936.hip | 144 + ...f6a1837a65df12b7c55d25ca28cc939c2a6328.hip | 144 + ...3e7888cba5f463d19fcb71aaaab25dc3d2c09d.hip | 144 + ...41910c34830ad2459fb85c2c14af02da718fdc.hip | 144 + ...57ea5726149efb8778e6d90798b8e48288fc9a.hip | 144 + 
...7feaf237911478173377a501ee19ee325b012b.hip | 144 + ...cca7528c7d1bf49ba79625733ff0ae7522c096.hip | 144 + ...dc4af43de08130a04bfa06df9799b6e9e96900.hip | 144 + ...e8ae99e184013739019c93d07caddce532382b.hip | 144 + ...fc5e94f89d6a9287cf64662a372784511468dd.hip | 84 + ...13d96a66a4d9fb8dfc84afba7e1d8c200248a6.hip | 144 + ...156f2c556c6ef6180608c361b7b35ede71ffea.hip | 84 + ...4c8003a508ed3f8cbe6967c4ae2635a491c721.hip | 84 + ...908fe6dc9c629c82d6953081b10021e64583b1.hip | 144 + ...960fe542635079de5eca3c7785890cd4740005.hip | 144 + ...fdde4b25e2fc8cbdd46c2850c19eac8d9af8f6.hip | 144 + ...309c036d96367939ccc3e8922595ac35a3e179.hip | 144 + ...513d6e065a44bcb0c789eed1e7e5456e800ab6.hip | 144 + ...5eb90b1a2d64acc0f6fbe1d807c501fd4be3cd.hip | 144 + ...89126a7eb09d81baaf8f99dbff8932fbeab3cb.hip | 144 + ...d73393d0d8b769f30222f7817563a955c36dfc.hip | 144 + ...fa51b8c7a2f3fac5cf4cd2951ed2ede5c35450.hip | 144 + ...5b08ca602fe48840c72cd61798acb98540fcd6.hip | 144 + ...6a418fbe6183d0392b7a7d9986d067e323e2b9.hip | 144 + ...7e33463b3bf1853c6d2d2009af8d27bf88abbe.hip | 144 + ...93dc3217e154b65ebba065aa10ab4dc2374ae8.hip | 71 + ...e3a06266deda093bdf28af82d8666066157fc6.hip | 144 + ...40e8899b4e632714632450bcef001c6070f955.hip | 84 + ...ac7f6cbdfca2e397bcb86af4216e87166601c7.hip | 144 + ...c04463f9c5ce565a9daa8c22e16de80fadd707.hip | 84 + ...d52c5f70abb525b9c8aa8fc1cb3997c33ed67c.hip | 144 + ...ea5b5346c87cc4fc1e841c518080df4ab811a2.hip | 79 + ...ed7f650c958a644c8031aeb88688b1e42458e5.hip | 144 + ...0aa875ac13957f00b30210477924697abf0c9e.hip | 84 + ...617bdea526d12d6a33ed42b9b0018c0b173722.hip | 144 + ...a3327da9a3411ff1cddc67eb647083cd947a92.hip | 144 + ...1fd28acfe85b3adac859c4bbffa4d28fe634fe.hip | 144 + ...58d4bca33c4c0e79141a56688049237d170d1b.hip | 144 + ...824621a50cdc3cbadc4b1f9ef18e1325385082.hip | 144 + ...980749c6b2a18c80426dd189e5506334343ca4.hip | 84 + ...dbdcd28cb2f078f89adf9aad2b3d4a0a477823.hip | 144 + ...17c082f249649eca733a8f0cdf9a1205c3e3d7.hip | 144 + ...9043572cabb65435627a3faf23b18d039bbcd8.hip | 144 + ...92990df507e82f96eeb7aa3ec00c01437566fb.hip | 144 + ...d1a40b12ce927323594fcce61eb9c20cc5e3d4.hip | 79 + ...d7b8c63a51c8639b3cf27ad09d41ae47c480d3.hip | 144 + ...074afcf33e3f3534ac3577484237fcfd2ca48e.hip | 144 + ...13c4f3f645a2bb475eb1c55ce1de452f0e2332.hip | 79 + ...3bd4e029bba76ebfc79e6522dbc8ca0bba5dd2.hip | 144 + ...4688cbd23727dd0ea9a36fb977b31aeae98d65.hip | 144 + ...7970957024de050748d3e31cef434f582d968b.hip | 144 + ...dcdeb845e7bcdb89ef70ab2a97157d4db3cb52.hip | 144 + ...f1007430da272174d3476d042f398627e83512.hip | 84 + ...079c1eb36db8461fa8b861c56760afcd97cc34.hip | 144 + ...7549e66ef309e32779ddc2a1f14e79bae53754.hip | 144 + ...79fe8a600c3b4e0ec9aa510f8036ba2b608985.hip | 144 + ...a8285bd6182355e3164cdc5a983375cdf0a61d.hip | 144 + ...1b48a28b71c7f4c78eb14321b39951a7c5e903.hip | 144 + ...2c587db8bd9f1b551624e0cf8b67a90245d7da.hip | 79 + ...2d5f979fc4fbd0991581a020a414f9c8656ae2.hip | 144 + ...431313fe082958d31b68d2fd0d61df0fe56736.hip | 144 + ...50ea8dd480012cbe10be392cd26d1870e6ef9b.hip | 84 + ...675919a6c7758cbbeecb83b7ac6c62f95cdb46.hip | 71 + ...812705ae3e452810794fa7caceef2ef6066dfb.hip | 144 + ...816fcad5e9ecfca94a6491eb2274bcc41e558b.hip | 144 + ...938d0e3ad30db201880642e57758285b2ec4cb.hip | 71 + ...fb5fc2ace6839eac741c5e6616665845f43566.hip | 144 + ...607ee20c0d92b6dbd0338f139517fdcce98d0c.hip | 144 + ...6e463eedd3e65b9c79feed3cd92ad8cbc9f036.hip | 144 + ...7166d4bb0c1c9b9999ba16a1adbf09ebfdb6f1.hip | 84 + ...a4c40e244b412a07933d369704bcdaa6d5e74c.hip | 84 + 
...b224b40a7be7db0a9c5c08cc5ab05b526c14e8.hip | 144 + ...b33fc20f2e85e915f1b1529ae87981dfcaf86d.hip | 144 + ...c08b4f3959a2375ac03f40c4ce12d70cdc2d80.hip | 144 + ...09b7d39346537aa6c4a4e46b81139f603edb60.hip | 144 + ...0d7f81c73b35ea64095d01c5d48d9190839e0a.hip | 84 + ...68ba8df8b0e977e9769f6acf6cfee6b00b9922.hip | 144 + ...6fa8bf5e992ddc25815486ae9c24d8bfba7227.hip | 144 + ...b17d8cba28cceddb3ef907df878aeef0762d15.hip | 144 + ...da0d469cca5c8481504148468460c85a15c559.hip | 144 + ...e5c56e92712d00092ba102a5eb5176a3e5d471.hip | 144 + ...0cb8bd09d287a1566265eb1e8894fe68d3cc81.hip | 144 + ...5b75db795dbef037b14b003ee073665fe35d3e.hip | 144 + ...63ae070075f26926a86d39e15c27e6edb1f1cf.hip | 144 + ...695dea4171747fb3cc6d910459f800608d07c1.hip | 144 + ...9ae177b7a793fa352c4f6bb8e4175f3064d814.hip | 144 + ...a6200e36944b1f11106c02f7fcee053f01ee71.hip | 144 + ...b9e2616c2fe0480096b1ccf0f74d584b220146.hip | 79 + ...c916e14198f6d18dc89915e379b01070434e91.hip | 144 + ...07a63fc55c411c73e4f93306c5ffed800dd249.hip | 84 + ...121fd448b4640a17e1a7fe73bb7b58714c0afb.hip | 144 + ...1f789d619db6f225e8e9d646e93bbc9dc1a669.hip | 144 + ...739f4464512feee083b875e11e11eee4f5b448.hip | 84 + ...992be6252f2afdc368bd4baec4b8a55ae0abf8.hip | 84 + ...b0770fe64e3c60b9e56170aa88bbf74802a813.hip | 144 + ...b722cdabcfaa388ccc6ccceb7e42462f3bdcd1.hip | 84 + ...ba64cdf615c1be2865f027a293cb530fc07dc6.hip | 144 + ...d841e6d783bb46d841aafd9027f92dd1b61b88.hip | 144 + ...e53359c69bbe4d7405d45261a8a62008eb7d06.hip | 144 + ...f9ad0fb65638cfffb3e7786f2cbf01d9585b23.hip | 144 + ...054acb8a9508fd0f0f486367fb62454de47c39.hip | 84 + ...1cf8d05cfa45319f4e5bb49334d35a530bffcf.hip | 144 + ...728d999ae43ee1b5a16e60b90cf8533c7d303f.hip | 84 + ...7801fbb43fb6797f0425f08d13926b74d87c4a.hip | 71 + ...7c48d0b7096ad6c8bc445f13f2c8c1934695ab.hip | 84 + ...b885d6869400b0dc2ef1b2c2636ddfd21cde31.hip | 71 + ...2439e4f5644a3a4630481bc7d98834b29b6e1c.hip | 84 + ...a94d145e575747c8956ac703810582c819e2e8.hip | 144 + ...aa519eb57e5797125728492d9330f5c0f0670a.hip | 144 + ...f6f9dee9f0c3825d91f4d320a5280070e60ee7.hip | 144 + ...061acc6650fc7b79fa1fe5b2b1e083555eec2c.hip | 84 + ...1343832a5bfd060c8d12da0d8a090f070a717d.hip | 84 + ...45f95c1093c60f0fb6c794636f79aaeb53b733.hip | 79 + ...530399ad7b43d8ce2c89da24c71056f2146b18.hip | 144 + ...83148fd684a7e6a312127e023798278415bd27.hip | 144 + ...94816877815bc0294610ca24f986fdccdc7c6f.hip | 144 + ...0ecb3013071fb65f2d5ed4c947c4bf303e5308.hip | 144 + ...38c9618dbf2af119e37596f7eb0fd3f8d72748.hip | 144 + ...3986150adcd6e1d3886bacf2166de1252e14df.hip | 79 + ...4f916d3484295b5918e2e4c22c5529588a5662.hip | 144 + ...89ecd7bf51bcffe9f5002959bdda41c50a3c8b.hip | 79 + ...8fc75a7d102aca068e3ceb6111728c280fa837.hip | 144 + ...c129dd4c798343d6f78ab78056f0faf2f1c9d3.hip | 144 + ...c5e79f54b71677124f555b0ae4bfd27248d099.hip | 144 + ...caa2056d99eb67ada498e287b4fae984397691.hip | 144 + ...dee49ec6755006d67f0c30c65f50558bba69b0.hip | 144 + ...f1bb85dff8c97846f6b2e8796a6289bcd0d9d3.hip | 144 + ...0073c70133ff2ee4737f803a0ac43801c47242.hip | 144 + ...1a08c2e48d805b295d979b24173a04cf58def0.hip | 84 + ...246460c21bc66c0f13936d27477a9fca1c44d1.hip | 144 + ...45b04a8026a01828c5dd606d89d044d3ed1d99.hip | 79 + ...6cf509d9c2bf86ba6ee5ded544fa8e6717f590.hip | 84 + ...7137b371df841993c8d0584be7d83aca6add78.hip | 144 + ...851d5ecbf02f8af623988b1a39c0b91e51533a.hip | 144 + ...01b25e0f132d647934deb395b62a3f70cc7c88.hip | 144 + ...7a617fae00fa90a1ba60937b0312c81087c19e.hip | 144 + ...7f00dd759d9714693e7517dfaa8bb427294d42.hip | 144 + 
...93336a4b00b2a63f23ed7e13ec54c82d9e5063.hip | 79 + ...e484adeddf3394d8d7693b808d83b64c71ee69.hip | 144 + ...f5efcd500ce6b9ffc14bc9877e0ba457539925.hip | 84 + ...f9a4f4d85f292b78123599a2e1798f12aa545b.hip | 84 + ...90e6ad243a48b84304b5cad0c663c0802aedfd.hip | 144 + ...ae680eed89ea93a3a94586bd5a68dbc5439f37.hip | 84 + ...e2f290b962f1617b0a9d4fd6d55c43e4439d6f.hip | 144 + ...f8352674bd6bbe98944a1c0a769a4fc028a623.hip | 144 + ...0a70932bd587759df1e5e150b25b0126d7b529.hip | 84 + ...20fa19d8d30654602e363806f559113218d66d.hip | 144 + ...8e04fe9432a60f86ff0369e8c1851821074a04.hip | 144 + ...9edbe35a8fac7796f00bde836bd547044770ea.hip | 144 + ...b73ea77ec20ea3bfaf995dacf93a6960ecdca0.hip | 144 + ...d1f99284aafc8d7908d062f179a056eb314925.hip | 144 + ...e866c7db36286876818bfb718ac35204fa3843.hip | 144 + ...fe4b6f3b901ff4af81bd4f1cd8ff19f09d0b07.hip | 144 + ...062dd633645772e4f2caffd111af73184f7657.hip | 144 + ...327f0fa1155f2235d76be45cd22e3db5a69429.hip | 144 + ...4dcde1ae3446b825dea739d4295c1d1ec5c4be.hip | 144 + ...6d08e63b9a90f2524cbfa8c5fcf8b82a1d2d36.hip | 84 + ...73c92a13757877f34bd8a13c6fb29b60999020.hip | 144 + ...841b7cf5da31f0c30ec42c91cc8d5bd3fedd03.hip | 144 + ...cc791049e3ff9ebc1a9085d2d20efcc2f99b71.hip | 144 + ...f235679af1ca03a6e601b4cf6cd0416d1c9091.hip | 144 + ...4fc7cda4b560040cec93f63021b529aa1ee3fd.hip | 144 + ...a3b1d36d777213eb381b47871bf15dd163c994.hip | 144 + ...c3ef3d3b36f52089548e9dce522b0448e2c26a.hip | 144 + ...3d274058bc0a3d4d35d90669587761fdfbdba1.hip | 84 + ...6759d8855c4c6289f1f241a1628cf0406c1b64.hip | 84 + ...69d441f48f9ea346dd8e00376a9a708da3ad87.hip | 84 + ...c424f0e192155e3c4e786e5b87d5a1a3e6c4ad.hip | 144 + ...51083e13aa4dfa8c969f8f916835a8e5e9ca39.hip | 84 + ...ef1b54d5d3841f3fa6b84cca6c7ad33efa2d9f.hip | 144 + ...0517550c7a23882b95de451e8099ea2186b4ce.hip | 84 + ...b389d4b5ba590baa951f17da06f0e53d2bfa55.hip | 84 + ...17be7b8bcf303b30a147f41346898acc5fab7d.hip | 144 + ...2a71fdd587e47ee68e0cc76c3c4494ce06c359.hip | 144 + ...2f152e9184af0b3d77082d8bdf519dbbfceb2d.hip | 144 + ...46e888e3836b0bd3c49fec8e1872e880798f0c.hip | 144 + ...874fc5ac87a1ec487c7722bf3b1bdaa924ee09.hip | 144 + ...94599fb5caf5e7aba728cd4713a8d0c6368a46.hip | 144 + ...a556c9358ddd6db719458c81d2d6d822a895da.hip | 84 + ...03cd47156a98ad2cf2c325ea00df3f1d67fb72.hip | 144 + ...89292c81a18d21a2921ce6740f81ebf4c046ad.hip | 144 + ...c71e7d33f0597fe090a3524e33e18b2e562680.hip | 144 + ...cba1509c413c870c5d784410855ee1bd737da2.hip | 84 + ...d6ad9de7ac7993ae1923a2ef070b7dacb8c563.hip | 144 + ...0c91b2f11bb7e5058ca7935b0bda4f5558a9dc.hip | 144 + ...1f3637624762547af1292e1b85e640b1d329dc.hip | 144 + ...25c4f1f3c7b271957768bb9235131c67afb48a.hip | 144 + ...482a64659c838f3da55f56e3cbbee1dbfe6722.hip | 84 + ...5e2aed617e1ff31f93ae7e054313ee0dceee97.hip | 144 + ...a715b7e9c1a576f011dfe5769c5b392e984f82.hip | 144 + ...ef5d30a2318ae06430d17f84878800c4ca7364.hip | 144 + ...339150d8bf9d073827738527f6cbe15b854607.hip | 144 + ...709e4fc53d2254a03ea7660b8c72d2f47cf1ad.hip | 144 + ...88a284f45f711d82a6ed87036d87cef1872eb1.hip | 144 + ...ac4f93722dc314086f1b7d7b8adc687cd75f82.hip | 79 + ...d7aa46528ee74e2bef1e87c1feceacfa55e173.hip | 144 + ...dc780b17152f696f9b957432c2eae8fb16e85e.hip | 144 + ...f9c236d24b30bc9c3fad90cfd6eb00da835de2.hip | 144 + ...ff8445ba691807caadd9f26e7eb90851875280.hip | 144 + ...21c2ed6b295c458071f1988b9d6f7b46e8992c.hip | 144 + ...700d87a19a173e84d64e43cffabbed52366e35.hip | 144 + ...87f617c4b84c6a0328fedac750d41dc3dafe27.hip | 144 + ...8843d844f78690c7a45b730652f0f763c595c7.hip | 144 + 
...980becb0d3149fee575bad1fc3b463d08aabf5.hip | 144 + ...b7f10440331a8a88ff93ba253217c2832bcf9e.hip | 144 + ...5b47aafc4340e69e300ac61a7601a5c14513b7.hip | 84 + ...5c7dd576e5b1061c059e5e99aeedf4389e2d25.hip | 144 + ...9423c095db052603d77073d409534bceef425f.hip | 84 + ...a7833f4597bb03a3e845d5580d677e97421040.hip | 84 + ...bdc110955c05c6c6ea236a6f60266a4a6dce5e.hip | 144 + ...c0109313de1f6245d2a80f8539485b849e9d55.hip | 144 + ...c4dc0d70c547dbbfb661e879ba7f9adfafc2ea.hip | 84 + ...d4eb673bafd81e3a0ee213da4603d88b8460ec.hip | 144 + ...e5cae764142683b70d3344cf07dd1edb7d69e2.hip | 144 + ...f2f0cef657ae5e333d65ae4ab20529a43cd7de.hip | 144 + ...f8b7b2a891aa9f2ab49762eb31d835efdf18b6.hip | 84 + ...fa94bb32a80e81886b711ebfcf2df5f5405866.hip | 144 + ...22fa57764ec746e02f6d4bd4846b48c722b807.hip | 144 + ...2a2ab489839ea1a1bfd1b24e54a3c232ed934f.hip | 144 + ...461d72fb6ba50e81de3f661528c96dcfdc3f3c.hip | 84 + ...4b4cf3f6706e4b4e0af4402e2263b9a1585f9b.hip | 84 + ...5c43b870705c780d734f9ef063f55cf8b3b52d.hip | 144 + ...73f35edd69241c6b921d6712dfd064d78ecbad.hip | 144 + ...1305f191f06cd53b7563971c706e8b71b19e2f.hip | 144 + ...4b0e7dd816ad08eec5a1bba6e227afee9813ec.hip | 84 + ...784b03ad757d51c234fa86ea9891f055ecd5c1.hip | 79 + ...8fecb9725ceb4bcf2aa037d43bc43efeb1c3fd.hip | 144 + ...f7553a7d2f6d42fe695cdc64423c85223af440.hip | 144 + ...21661d8280c6e9d27f2c9ce1b3c855387b5a76.hip | 144 + ...5d35b2fd98742427930eb536e346ffb005edd8.hip | 144 + ...a4af070ee46d802cb11086b93daf91538f8a04.hip | 84 + ...a744edfa3a19d1493611df5bd0d4d59b707d43.hip | 84 + ...2b43d374642df991edef1f6036dc898bf77cf8.hip | 144 + ...3324ccf11b273ed20fd960c61df897c8890b1d.hip | 144 + ...3a03b33305b33055273711ab31a5b8d8298d5d.hip | 144 + ...68df29f5ae1463706b7981b3bde55918e1aa65.hip | 84 + ...8925d99dc484da41dd55700e151cf545cf821d.hip | 84 + ...b50c6ebb27986ce5b378d8c39315eb9cb91dea.hip | 144 + ...d2be18e2d53a5144f97dfdebb225fcb6d611d3.hip | 84 + ...df9ac4ee78e5f4d5bd0567e58a7090907c61e1.hip | 144 + ...f00f270680de81df7737e848e0408cb070e68b.hip | 71 + ...1041530f794c7b8dc4a8321ea0fcdd338fff35.hip | 144 + ...522b43c5e5ea69bcabb4c0fe28def2bd081a12.hip | 84 + ...6d13b09f85ee62bb5018608812181fb43afc86.hip | 84 + ...82d20635e592edbf00439294835f6f39ad54a3.hip | 84 + ...996b9c843200a2ec33ed4319b48106cd7c6384.hip | 144 + ...fe891dad43815e635f81225705ff944f990d75.hip | 144 + ...09941bddfa9d61985b55f9b6bf0edec9bb89f6.hip | 144 + ...0be5a2072b5e87f5ee58149688796b6513219f.hip | 144 + ...0c3fe9529e24327686070731d0ac3ada76245e.hip | 144 + ...1ca4ce061f7f69a250356f613cab00d1e2ac71.hip | 144 + ...1d7f93427095e39bfc1d986b3d7fe54073ec75.hip | 144 + ...43f4a56c166dad0113f51b337a083f4df7cdb6.hip | 84 + ...56e886d53a1d88fada0f10f00b9f398dc54568.hip | 144 + ...6cd5c9242f8278c8f3d9ce57b97d605c7e5a3e.hip | 79 + ...877ae2a1aab04498bf2b26b3fe99d6488ef151.hip | 144 + ...f6c6412f9853855b74a96e862935ddef66f763.hip | 144 + ...f92a5314fd33491b5eb6ebd2418b7e0d5db774.hip | 144 + ...1ccde31b47e0e56ee0daab6403fed7895208c7.hip | 84 + ...5e9aee85cd16903bf7b82a4ac10402b0b26e22.hip | 144 + ...9382cf8bb56ffd962c99329bf67da992f8810d.hip | 144 + ...eb0641213e9a45ba48bcf72bb23845720d8b79.hip | 144 + ...091c69d19b27f7ad50ef6311532ad8b642a9c6.hip | 144 + ...82071cc074fd30437f6158b5eb2c6df1f8c587.hip | 144 + ...989d2ce769f20e175fa88f4082c1c25fe03062.hip | 144 + ...9b99a194b59d3149842c15733394da275b12c0.hip | 144 + ...a016be2bd0e377fbe01fa7adb9bbb8febce100.hip | 144 + ...ae2d4f8b2dac799e03ea6f279e6ecdf66f5381.hip | 144 + ...aef10ff2c5d89530310bdf1d53a194f06a94ef.hip | 144 + 
...d29e3e9828911a117dccaa5650e77805730d14.hip | 144 + ...da7ad787524e3e47dcc1b65c41b2faea38f55f.hip | 144 + ...db6a14043c5a4df0f5042b3770b40c4e90795c.hip | 144 + ...f160741a4f751d2f15d6eb23d4121cdca62b55.hip | 84 + ...1ab1f4bbe86bb9bbc22e4774648076c321136f.hip | 84 + ...1afeb6cfdf860ff08e4c2f11c922fd5bfa621a.hip | 84 + ...239476d61f48379754b97f29d7a285cc3192de.hip | 144 + ...4e7253ad4873576052ec0a9400597bb7975753.hip | 144 + ...4e80cb185759dd9b3eb3c67c239964b3694caa.hip | 144 + ...51b30c7e1cd30e550187458350c8db7c59a9ef.hip | 144 + ...7899b1ef159ecbf01f27014601eb79b31b49b3.hip | 144 + ...87b1d5c50606430b544ed650d87df24366e7d5.hip | 144 + ...8d0bdde763e617beafc0365ec4a3cd11df6c55.hip | 144 + ...bb2441e6cc1ccba4a391566e547402bcf7ced2.hip | 144 + ...bd5fed34ebceb879ae3dffaf58c7c04ab5fe80.hip | 144 + ...bff7e6605b273bad844b8f70ef031625bff48e.hip | 144 + ...c87e65afa93e84d7a947c52f291c1c7360033c.hip | 144 + ...ce14f7a220222eb4ce6783ec2b9fce6fde94b8.hip | 84 + ...06c0dae15684f83e15722a4c07342af9ea011c.hip | 144 + ...6ccfa11add1ae49888337e84d9c446d2f67da4.hip | 144 + ...adc4f76e237514db0bc0203102297b79730bd0.hip | 144 + ...c4b47a6fa62a4ca5cff6a7e01c9f6b371d2215.hip | 79 + ...cafd07c1f56e74373ccf37db35976023456d50.hip | 84 + ...ccf699f593c828e11efc053b144044e45b32d6.hip | 144 + ...da8f46b5ded4c2aa9d722fec17b75004b59f7d.hip | 144 + ...dab954fd111ec48721f25710d61c0c8affd8db.hip | 144 + ...0e062055933388e37525df5766f3c14cd3538a.hip | 144 + ...1dc872c24db4db0c9179fc07e17f41060390de.hip | 144 + ...3ab68e33844f97aa58d463e00037bc11c50da0.hip | 144 + ...4f14f829eff73afaa57a875f74ebd1e6860979.hip | 144 + ...544a38dfdf4d81dc95894387845f48435e299a.hip | 144 + ...dd965d5d9080ed5c6a04b7eea9890f3a264f20.hip | 144 + ...f555b74ed36f1bef8f47880b3edc6760f27788.hip | 144 + ...766695dbb790bd614b83dc7569ad449404cc89.hip | 144 + ...8a615e66d7cd739ce35412811359a03cb23a8e.hip | 144 + ...92c55f002d8540d5f965cc4df0c2e33f4b9ff9.hip | 71 + ...9f05f6848403480ba41d37cdbf44ccca1b1f8d.hip | 144 + ...ad101ce91348266d3885afdf2996a0fdb72135.hip | 84 + ...c5d55d47d6038e9162d32ac968ff58c0942938.hip | 144 + ...0c6252863a73341b0010191fad4c834860f884.hip | 144 + ...0e314642cf565e4f32bceffdb5c0e653ab627b.hip | 84 + ...4f91dec2029b25d0d96962528410df55a468ed.hip | 144 + ...85e2f1970b78e18002464eeda63798229bbc3a.hip | 84 + ...98e213f927b518c693660110f08bdd94990ef0.hip | 144 + ...af5f5b5ee3ae964824a3e9c7bbeb5bb39c557c.hip | 144 + ...f91e937b427ecc932c0cb0c90b2c2378db0be6.hip | 144 + ...063d06723ac70c5f8802ab49c5c35e1debf56e.hip | 144 + ...1f56244076c501cb09b4b90975132cae4c4386.hip | 144 + ...486244e0b7d6dbcaa1951e8b8883ce441c3f99.hip | 144 + ...4c1ce348c3d9cdf6bbec9758de9d5fe94c43fc.hip | 144 + ...8a1d3cffae01332a3a9d9472ff1b2c443e82af.hip | 84 + ...a104733f678193068d8642d6560faa03897258.hip | 84 + ...da22d3482738a8474ae15e8e5fca9020c4e195.hip | 144 + ...1735d250b5a16967281a5f07873b9cde3df4d6.hip | 144 + ...1a30092e8138877c1f6c25656e0f8ae2c2444e.hip | 144 + ...1ea5293bc1c56efa2c4b5681d965aa6f2ce6c3.hip | 84 + ...588379eaa268d79fe8f8e4457b009f204a5fb7.hip | 71 + ...93c99888d82cd2852bfb101f99a2e6a27665b8.hip | 84 + ...a5715b550f67b8870ba66e1e6282a26cc1dbf3.hip | 84 + ...b037a2e262d11d3ed7d9feeb41b9e05427a739.hip | 144 + ...bd2d206ceb237ed2c51f58abb5cbf96e39d07b.hip | 144 + ...ec377c44ac18527ca6a01bc3b146706a6e1e09.hip | 144 + ...f12f10d7b968e0d8e7c23f36d3a360de74a905.hip | 84 + ...0e6df20a2426abd3d2ff2262a37c009196024c.hip | 144 + ...13834918d5ea789e2db21abece7c2d3532a7e7.hip | 144 + ...248f443a12d96815c04409a00102923c717023.hip | 84 + 
...371415448fffffd58bf014dac9f4876153657b.hip | 144 + ...ac596c636df55e81293228cbc53dcbb3024e5a.hip | 144 + ...ba2e73df35f6e0f7317303823fde92a42b1a35.hip | 144 + ...bccc85f74f54a2ceb17fe3040b04fe306c53f9.hip | 144 + ...c3131fb8e5a25bd4a14bc9075eb6fa01b61d02.hip | 84 + ...c7fca1f76a31b0390e92d90d569fab94d4f783.hip | 144 + ...db3d5b1d8af89381fc4b8073f84c5fa25fdef5.hip | 79 + ...0a4e87a7aabfe3c1ce02b408522f3ec862e3d7.hip | 144 + ...b17ae67adee9e56a022cd2a5514fb9c4e99920.hip | 144 + ...2a804bb3c99830653d41ac0bd49943c801b89a.hip | 144 + ...37410b404a51043fc3bd503c0b107c297e4c9f.hip | 84 + ...5843bb13058ffe29251e053800c509c7590544.hip | 144 + ...74450ebadaacf23e944aaf8ca90eada01e8a5a.hip | 84 + ...79cc0b0380e1e6a2b51fc6216fdd72215b882b.hip | 144 + ...a03ab0b7887cc7ed0cb40e56360a8d36c0bb8e.hip | 144 + ...0d0828ba6d24ea3c1a97bd9835ee937b4b32fb.hip | 144 + ...72f9e6ebe330cc1818ea82b53acec79a2f672c.hip | 144 + ...fbc6f6e9c515edce3c7a438b3bc308b30d3857.hip | 79 + ...385db12001110c42eff6aabad935a69ad3afe2.hip | 144 + ...559dd36a0a4f5e068a722e285f485137bd5ef0.hip | 144 + ...627f9c8d0088df0364a64643f2b5dcd951f2bb.hip | 84 + ...a742ceeb6736a2c8f9439d0b05e10d3e0c5c6f.hip | 144 + ...baf70220079e6d4e87eb01a7259923d8a01e29.hip | 144 + ...d00ab8373747a5c6b9d2f8dd50ceb14db4163c.hip | 144 + ...ed0a64deb55616646ea98b21a891c971cd98ad.hip | 144 + ...145535e53899fe127987aa854f81234a9c51c4.hip | 144 + ...8b09f0aaa40a7c9ad5f0458b460d3e328f3c74.hip | 144 + ...fbef3f13d429ec3e9f4672218998d5669d79f2.hip | 84 + ...111b7acc269f8d5e70915d3efde4c425aa5f5c.hip | 84 + ...28a4e95723e3df380f98b5ac107c4df353850b.hip | 84 + ...35c86443cc9ea38c06ebc0656306483c95ef67.hip | 144 + ...a10ecb79ede07324e1198a71a95ff26e9eb235.hip | 144 + ...e23201fbebed25781f249e5c77c31e0e7f9ddb.hip | 84 + ...fd025488e52b97c04995c4c5faff371b77e4d6.hip | 144 + ...1ae1dddb8cc5d78196da6b26ebe66c1ce7e567.hip | 144 + ...238fd2095b26a167b41cdec8280182330b7b25.hip | 84 + ...4425e30a0b17e8b31726817e8d3177b5c51934.hip | 84 + ...4e0f0496a34d2fb43c80ce0162ad4183f29064.hip | 84 + ...6ce17223d8d83a64b8c96ac88223e4441a4692.hip | 144 + ...744db85d4237ee9640f1658e0caab7648e3bb6.hip | 144 + ...79e255d25744725e2a9db9f90d5cc2b8a0e0c1.hip | 144 + ...897852a4ca992961843144f4ec4f8b86dd5e9d.hip | 84 + ...b6f0730fd09b4c6c60913425927dfdb8f83d82.hip | 144 + ...d7ccdceb7baf3b986f2a0248827822a5f72e47.hip | 144 + ...f8836c8cf932cc2748e313885003f0e11a887f.hip | 144 + ...064e302ff5b983dbdb4ccf51383fb29ddff44f.hip | 144 + ...28203f47b6a48e9b66302cf8312f3796ca500c.hip | 144 + ...37f4f7914805a97d5073f1ebf8a8b8c2648d31.hip | 144 + ...3daa5f99b4522d932334924347353ce2854821.hip | 144 + ...6aa39d0ae3c87d011610cdb5e2e317f337c454.hip | 84 + ...80a1774d8b7d8bee4e8663392b97cda11dcbf5.hip | 144 + ...8bf7c572c1984ca3061062cf3c31d993f6762d.hip | 144 + ...9c47f3305e47db6ab6bc627fb3d80269633074.hip | 144 + ...ab172627718278a71a93e3737ef08ad9259a4f.hip | 144 + ...e24a8dbe6add6f2dd2beb48b1280f3a84a9b2a.hip | 144 + ...1e1533fc37b41838bd37edc2b6d2f2e76ae1c6.hip | 144 + ...4dd90ccb2f258029d0156cf23f940b694cf08d.hip | 144 + ...8ec1163a01b9cd9a802d8b44669e8770c20234.hip | 144 + ...ae876d6da465687f162136231f15767cc7bb14.hip | 84 + ...b9afccc15de7dfcb2e7d898abc0d61201de73e.hip | 144 + ...c30e7107c5dce3fe6aa87d83ed96da75478da0.hip | 79 + ...c9e4c0317e8d351f60258ed6611fbf365c4024.hip | 144 + ...cc2a4d7ac045365300bf8bd45fc6d3e1e1c8b1.hip | 144 + ...d5a8c5cf683f6dfaefad72c2e2f5c2f2b2732f.hip | 144 + ...f3bd014a918feddadc98eed92a7734f9bcd890.hip | 84 + ...9cdf86a7944cd690b0fcbbaec235863acd10bb.hip | 144 + 
...338fbc05f86270ded7df2bd3e2758a03961b62.hip | 144 + ...342686e4efd26413c6719782ed13603479c4e0.hip | 144 + ...63318cb851ccaa923be12d34c84d839bc64bb8.hip | 84 + ...8095341ca7e3a1debeb780c1878e351692bee2.hip | 144 + ...a3c4ac0a50bb9b7ad764929dbee98c856b1210.hip | 144 + ...f76aff077c28f8afd7b22f284cf2894e08a043.hip | 79 + ...12c01d201c366bdd7acccf2e1b18b00f671153.hip | 144 + ...1d68fe766fc753c657362673704005b538660b.hip | 144 + ...37c03bf161b2ec6a9a046fa49d7bbf80ae47b8.hip | 144 + ...97d1f050f42d82e6851fa286db6f81ba197f40.hip | 144 + ...b76bc7a17f573c0d52c07ae9ff4302662ae61f.hip | 144 + ...b94e19d762ddc33cc4e94c6675d93cbde21e3d.hip | 84 + ...f40c3421b9ad8cf43940530ec50bcf620058f2.hip | 144 + ...f721a330b2d0fac13b22061616d7b10c0f91e9.hip | 144 + ...50ea59ab6e1ee39cce15cbd3f181047cdee31a.hip | 84 + ...541b6b5cf27de3f45f60671d36602f07ce1783.hip | 144 + ...7b3026f1dc3056dee3a3e64bf31c45683607c9.hip | 144 + ...8de8f96c8315877031a2d56261e95fee6aef44.hip | 84 + ...9110dd501853e87ebc122dd1971b0bb1bcd92f.hip | 144 + ...940fd05efd52bdf8a3f9aa4b78bde9b5809b34.hip | 144 + ...a2856bf9a81544a30d535a13554e3a8107c476.hip | 144 + ...b719893a4d8a1e71857966d399f06c0a41749c.hip | 84 + ...f04447e6a94c94a2315454e71d7d607a9fd0f8.hip | 144 + ...fcced07cc194a8050bc7b2f791453b3f5b2064.hip | 84 + ...23a4d1f24d59bddd20ed2f2fb6446627b0ae8b.hip | 144 + ...55189ade9b1a8269230232db754a3881b53168.hip | 84 + ...5ea54eb6cd0f3756c462c66d9be956279b46ad.hip | 144 + ...63ee1b087f6b504a3dd3972b96e77db02b0582.hip | 144 + ...cfaf0d53869c373f6d0ec821b008dbb819141a.hip | 144 + ...d0eaf9399c863d672e8c08d123739bab837d4b.hip | 84 + ...015f0d0a7a5173810f6f17c00065e03fc61a89.hip | 144 + ...02e84359b2037a29efd1d6ce7213ba7605ab25.hip | 84 + ...1b6eda4f250da059fe0c428428219ff5a250ef.hip | 79 + ...2ab428503e8f8bfa78c8cb8d9afad9f5185118.hip | 144 + ...376ac8d82db1bc25fa273a80dfbf8b71ee5e2b.hip | 144 + ...5a5e40f6a66bc5292a56e0097c69fe37cedfb3.hip | 79 + ...87a1a9933239270f44b1e08e1cf5323521c089.hip | 144 + ...997f79435cf64add10506acb97d0647cfbb3d4.hip | 84 + ...b34d3cb673447773f6da23e9cf52b98e99f718.hip | 144 + ...c3425fe683d35dc3335db77d183ad1620b7a92.hip | 144 + ...c6c405cefe204824e8fad1b3dd34bba87e796a.hip | 144 + ...de1bc135191f3c2aff740f4c6bb7e98da42f84.hip | 144 + ...dec99707511cebd9188d216ee0a148d729b470.hip | 144 + ...38dc4f65d02776875627cbd20a9c794d70b043.hip | 144 + ...3e295b68e807774ed31bb914e4bc59312a77d7.hip | 144 + ...6aa150611b0d4800470c1493dc907082a5c23f.hip | 144 + ...81974c8b6f43f60d0af29c350d850b55c03121.hip | 79 + ...9937be2b9a13d6520fdcc922e4e75c9fa085ab.hip | 144 + ...9a22c6efd8bb8815887325aa0b739e260cc754.hip | 144 + ...9ab718fa23f24f09a713ac28a339208a7a5802.hip | 84 + ...b440ca9a5196ee1e72c878c87d96934e9273c8.hip | 144 + ...fcdea177734366d3bf283317a65cc3fffda611.hip | 144 + ...fef330a975002ed15670e8e7b26a10376d3cb7.hip | 144 + ...4f4cdce32189065362a502105c31bd2d9d99a4.hip | 144 + ...e2da8b791d31f4ba05ef5f833fd6dea9e35f1c.hip | 79 + ...568e11e44ce70924d27e683190422cfae5c31d.hip | 84 + ...af2bbfac25de2853be344b9f636226c1c0112d.hip | 144 + ...06d7803d06ef8aac1d5caac9f36aafd47653d5.hip | 144 + ...0dce1a17d073259250ec0c87ade69e639ffa8e.hip | 144 + ...dbfaffc8a9b573f194f9c63f1175d9725f8950.hip | 144 + ...f6461673882d636772ae4d26e78eabcb568f31.hip | 144 + ...19b8ed877d4244d01a17ecb948b459e361ff24.hip | 144 + ...21a4790f982d48bcaf950123c699647afb739b.hip | 144 + ...312d7159369d13f3148a6f0882dfad6921ceec.hip | 144 + ...530e20038eb40c49bc8b045be0cf4e7e6b4eac.hip | 144 + ...77735a36c325706bd19a12df66ed0839b032b1.hip | 144 + 
...ad71883a19b522486706d3705700c012a6fc19.hip | 144 + ...ba0a3369d4e4eaea1c902a90e6501f232dd57c.hip | 144 + ...f1e7e478a2208c4d32e2d7e6abebdc16bcc5fe.hip | 144 + ...f28230817c9d9805c41dfcd4e834fe302e1df1.hip | 144 + ...fb8343e623e46f01893a2b61345d1ca5928671.hip | 144 + ...fe51f982abd60e567d4238d3266fb60e45814b.hip | 79 + ...00cfdc5592b7440d72482a18781e9cf3afb05a.hip | 84 + ...1992a2634cd6674076611be54197c715ad8271.hip | 144 + ...3975efd767ddf7c12e308d948bdcaf0968493a.hip | 144 + ...3d98ff43fbb80ceb82fc22ab039bee898969b0.hip | 84 + ...4c6ad28aff1976c6dd36974ec3b339aa3090e9.hip | 144 + ...5681d4e5871aacef74bdba9e368445875252d3.hip | 79 + ...920c3239bb5796b1ab2fc75177eb3b820aa784.hip | 144 + ...bb7b12cdd9b8b522af577e13232b2459dbd38d.hip | 144 + ...e6c7efbfc831e2bcfc8c1efa1a486c02627cbf.hip | 144 + ...ede7a18f3e3d5e24f6c70392413a2cda16ac15.hip | 144 + ...10303a0b79f2710eb7c66896d3c1f8b12c04dd.hip | 79 + ...1a0ce432c27f4cfa51731c3ef181bf60c8a727.hip | 144 + ...1b91c16e0255fe7a0a85638b98d94634e143a9.hip | 144 + ...1deea4f4fab0db31d46a91228601f0c272d6e6.hip | 144 + ...20538073888bdb3174a8e9c32d7449072aa753.hip | 144 + ...3d5273945c5d40cc05c2660af2df1fb7a15f3c.hip | 144 + ...4576e8ea5d59d7663f3760009a00a19e1b0667.hip | 144 + ...d571f4fe576fdb17d5f75a558cb6747087c7f2.hip | 144 + ...e5a98163e878c7697e554758ebd0597c2c1760.hip | 144 + ...f3e4d4d4837a0cb33b78c4f2767b1d93da0850.hip | 144 + ...127a63d56099e08125b16939dac82f0173122b.hip | 144 + ...4ac5a18f57f2ebb65f7e356e858ab0d59b2133.hip | 79 + ...54b107e1b557ea36b5cbaf7fe3dfce05415c86.hip | 144 + ...ac6c0e61b65c9422c7f30fbd979031698370a9.hip | 144 + ...d0b777df1328bf24e070ed4cdf8615bb2199fe.hip | 144 + ...0453a5c3828c1358360f31f5d3b7258e17fdb9.hip | 144 + ...4efcdd12184211c74e7b3f2f30fecf1041ca32.hip | 144 + ...757a8bbeabd16a44d149ab188430f6d79ddcaf.hip | 144 + ...e0582e1aef74f9209de638b553ec0671476258.hip | 144 + ...4714e4f33340859c106a3129993e22652262e2.hip | 84 + ...5064e27ba427cb951f7e1b01328b0beb6b2b7c.hip | 144 + ...5ad502dd40353312d561e9f40aa478c16ef5b1.hip | 144 + ...5b5932f6df9a194ceb0d69220fba9596528eec.hip | 144 + ...5c161b725becf059fb4439c668edd454ac77d1.hip | 144 + ...909cb5f96a4884caa0d2eb8c5e6bc7fa352797.hip | 144 + ...b9544e2a0caae2c9e3dd8bbd2c509e8dca1379.hip | 84 + ...e81ab2e2678816c7b516d2d4c50e8cb5874c68.hip | 144 + ...5c6c0bfaf98f6e655fc443246b81fcc730fe97.hip | 144 + ...73e1fc0015094861ca0c1c81bacdbe0c5b8f37.hip | 144 + ...da56a4eb08b803332f25bda6209932d9624acc.hip | 144 + ...ec97bdfb6fa95e057eaf5a8138853e1c0884f2.hip | 84 + ...0f65bc99ca08eba66564d34f72f2769bff9491.hip | 84 + ...36096f49a89730f8af7e75457c88cb8ae64165.hip | 144 + ...49a1b8f4c1c6d37973ce38593efda1de8ce0cd.hip | 144 + ...4dc4ed02eb42c3fe303342801ed3073a0dcb8e.hip | 144 + ...6ba4c996570ddab77b6ff1e2a0101b638543eb.hip | 144 + ...863830fc5d43dc6d6400280e892bb7de2892d4.hip | 144 + ...90b771a4f9750132f549c82a88b4ab00dce5c7.hip | 144 + ...b09e8513646fbb2a007544a63ec9e2b04dc4c2.hip | 84 + ...daa59f5dce6fc3965193ae37d8c82a3d1834e6.hip | 84 + ...dd0165ee91c095a19ceddf08789e3576912590.hip | 144 + ...de618ff3ea9f67b90f2227fb7fcc74ea34183d.hip | 144 + ...f63cafbeb445408c884727b473667fb479675e.hip | 84 + ...37b7b6e04e1caf43a62bd6788a75361cfa98f6.hip | 144 + ...840494c4fa78ff399c0399b3ad7ca3d22d4587.hip | 144 + ...8727988e47264b42b4153dc82fc1a750f08db0.hip | 79 + ...c0dfd19a08d61586758091370acbdc6f267017.hip | 144 + ...c25cfc437d8bd803860e39a45b2f3b9fa48393.hip | 144 + ...d3eacc320104100bce46235fe656e5a8223c66.hip | 144 + ...0d45aa85c0daa299da98c277cee826fe67bd27.hip | 144 + 
...57148f457557ea80ca56690e525db3a4b0ff55.hip | 79 + ...5ce4b3e9cc392ceafebc7fe3bcbe05aaad4bbc.hip | 144 + ...d08c5470a385d0160b2c1441fd1c30fff1c17c.hip | 84 + ...daccc4b3a0f90bff39cb4597f8b7e484613d9e.hip | 84 + ...dfdb42c1b380e860aa5609302f29698dd27923.hip | 84 + ...f4b869ff23874b6bde0aab68c419108b7e69f4.hip | 144 + ...2c64ef01aa228277d031a74df51363f98aa2b0.hip | 79 + ...4d6cdcd81a456125ab5e0875466c6334d8e5c8.hip | 84 + ...4fcb56caa8f80404789fba0ffac447483a4d84.hip | 144 + ...784fb4c0685d7b651f4113f3c71e050881f3a5.hip | 84 + ...a23ded424200d0c6f06b1dbd0a7b7b0e7b5d9b.hip | 144 + ...a2edf232786d458e2125f8dfeda8847f842afa.hip | 144 + ...af8763f289dace1054bdcb4dfeda28b0aefcae.hip | 79 + ...fce1e11aee2273620e75efe4aa0390fcde9ba5.hip | 144 + ...0569ae9dbd693c0ab3d6ba69704d31e451011b.hip | 144 + ...1b6a64dd181f2efa65aaed03a3d229b3566c1d.hip | 144 + ...1cd6b60a97e7071518cbd1a63abb8b910df024.hip | 144 + ...3715cce8935439f90172d141050d78c7e76fb7.hip | 84 + ...605b2ad3e3753c5f255678abc1690b949c5abc.hip | 144 + ...645b713821371161a9925dec8a3d6c157ba1aa.hip | 84 + ...aff499ad527be5fe33b8e92547df57af26d40d.hip | 144 + ...b99af9a573df50a27fccbec3fa8e350f1854eb.hip | 144 + ...c9f975891087e6eed6393629b41155deafc509.hip | 144 + ...0ac8e8a03f8e7ec2c6e993dd39f09f465dab57.hip | 144 + ...4ac01458df3f240e0656d82330f9de23ba9651.hip | 144 + ...4b3731883a5f8393d60d27487f8d017aedd3f9.hip | 144 + ...e82799f4452e148c3e02acd6526cf30757eb52.hip | 84 + ...edfe3e3dc3008b928c8e6dbd50784b905f189e.hip | 144 + ...00779c17b7b21c18e1308e6d765fe02a7945d3.hip | 84 + ...149eea92f2c40c11de3b778102fcf9b6a006b8.hip | 144 + ...23b36cc3f56d1001b2d3abadd8a5628fefd014.hip | 144 + ...3c8c746055851217a514321cd735eaf6937263.hip | 84 + ...4b8b52f4a98801e185e2f132b2f80c29dd0c37.hip | 84 + ...6b79c4ebdcfd239cecec58203606bc123bd6bb.hip | 84 + ...6c30148a6fa816937f2f095802264d3dfa0273.hip | 84 + ...03eea8075cacec4d41fee7dc4734f593ee79e8.hip | 84 + ...12f23ef88ae5d7b161d36f42d22a5ba53b6354.hip | 84 + ...13fe25dc90b3511fc259cebf463376dcb55d84.hip | 144 + ...145383e39dec0e346b5094401acf85ef3c2075.hip | 144 + ...23b191785c97d284675f700a7baeb52a2eb791.hip | 144 + ...290cc4c3036c9205e689cbcc60e7d16b97a7d6.hip | 84 + ...33f4c03e338ea7c6d8f759c1132499bdcea059.hip | 144 + ...73df9ccfc1ace90fe3afb5c00976deabedf6f8.hip | 144 + ...adde8780b39f1364c572a19c3bfb19417678e3.hip | 144 + ...bda8157fb27d544e049fd7d2ec735725f1bf44.hip | 84 + ...fae2c18645d36a181a0bdd2d8ca7a4ac0f6d1d.hip | 84 + ...2773721479613ad72e334510a248f1436b38d6.hip | 144 + ...67098db97b3f26e71a151c63b74260bfab21f8.hip | 84 + ...6e4dcbe9c4cac8f7c8c5d97ce384ae0cbdbfbc.hip | 84 + ...901a63986cc28ef24cab012b32114851a8c1ec.hip | 84 + ...061c204d8a85c974676f4438994a0be9d69a60.hip | 144 + ...24ee32b178b6bffa7a71603d6e2818f66177a5.hip | 144 + ...37609afa8e21a761dad6b01ff3f26346e450fc.hip | 144 + ...5835bc6f000d3a3379bbc38d90e83dcaf867ee.hip | 144 + ...92eab7de49033f5480c5e86a69e675db0d2a19.hip | 84 + ...c23b7f8fcc4e4f4c81f5f00cfd345b98df2e0f.hip | 144 + ...c3e27b522320dcca5ee84fa534b03aae2bfea9.hip | 144 + ...07d8b5666423da30a95e3b2cabd3839d200981.hip | 144 + ...29a515d14dac02066bcd4701285b9916b43cf5.hip | 144 + ...6afccdee4107507a64323e17bf12c46da2b92a.hip | 144 + ...74887afedbd67928fe4d596709f9ff92530611.hip | 144 + ...822ea727fb3543e445e4000f7e6ebb946d6a3b.hip | 84 + ...9f6e1d59132fe96709490af25bd794f267851c.hip | 144 + ...0d0cf55d90b3f3c9eecada1db93c420f34b1ae.hip | 144 + ...5016bff9e5dc37184d2b9417eb351c7ea1c322.hip | 84 + ...85839ee8d464c5a81b8dad9839f5e0f4b467a8.hip | 84 + 
...8f0bd93b352d28c5b6d78f4332026993f0bea4.hip | 144 + ...ae1670fac6812b2d2cbad973e4b475509ea504.hip | 144 + ...b06b43d5d65429e23cc717448cf1fffb0cfd74.hip | 144 + ...c4135fce01e8731fec7a78d0cc0fdeeae28b90.hip | 79 + ...cea8f7b5930abf76eecefce92d0db785d2df5d.hip | 144 + ...de2ef18e2174ebe13a6e7c8c2a6b05a6612047.hip | 144 + ...039d422a57c159ea4dbcc867d766ff1b356a07.hip | 144 + ...08afbff5def8bcb4e823657ce01f57c9dc77c9.hip | 144 + ...184767d723f4995791848cdc68bd948408204f.hip | 144 + ...1a7f9b1afeba6690fdc0d0d1755ea89c805573.hip | 144 + ...34b6ef496d4e0d8fbbe10731d4a7b1c136c036.hip | 144 + ...3d625c5ad3e871f5a727ac946df642d988b9ab.hip | 144 + ...4d27535b9570b8f4b790470a83c1d0a9a2b6ce.hip | 144 + ...5ba6d73f331c76e696953606c5b347b6a46f3f.hip | 84 + ...62a8db637d32e7dfdb2521cbdae6e1fbbd5fd1.hip | 84 + ...818f3ce244743cb1dbff9aca399df90742a6d0.hip | 79 + ...91797c1474a368e9cb056b50b4629d7736c3cb.hip | 84 + ...9e54273c0ea2358fb573a7d918aa7b09fe07f9.hip | 144 + ...f815ef540060cc7ed43e1c57a28e1d080c5621.hip | 144 + ...10bbf37503bbc92af82bc3487989b41b20ca85.hip | 144 + ...11806cd2d3ef1127f676b2d98bf8fff2a1e5ab.hip | 144 + ...35634440edb25cb095800b882c70aaceca1dbb.hip | 144 + ...67d442001d2b167e70e8730abde4d4461b8569.hip | 144 + ...9494d9ac35eba6794a4f9120d2db9932596ef8.hip | 84 + ...a8d021381083bc48b7fb1840729254dd8e5137.hip | 144 + ...cb1cfea1b0dbe50a02252cba99428fd977527e.hip | 144 + ...e93ffe7fca311e136e42fbcd12b05c9fc7174c.hip | 79 + ...f5339054f47d9ed6cc7f9e66ab21ce3bccf3db.hip | 144 + ...1ff66d2aeb47d2fdccaa4bb6b9d066b380c99e.hip | 144 + ...26a187c4db06115072a5132e1166b5b03368b0.hip | 144 + ...36bc309877917a18fd21acb30563c7e2f233c1.hip | 144 + ...5359f0fba3da9dfed06ddbea8fe2a33a9cf40c.hip | 144 + ...6683d175affaa5ff261ab8503f64172d8eba8b.hip | 144 + ...7eb562a7eff31d589e12945d80233aac202ae2.hip | 144 + ...85901d66dc04b1143bb6404445baf65693b781.hip | 144 + ...b9ec2cccab94920e40f62a1f0f094acd919d07.hip | 144 + ...0b2bcba57e77d975ec5304fc50cbd09cddf4bb.hip | 144 + ...4bb75ca79f805a81fbad750ad22f6d22b0d8ff.hip | 144 + ...4c9eb48da49a61957537270d94e56cb4e426be.hip | 144 + ...5b1c6758d4b8540158299dd0362297083084c2.hip | 84 + ...645b3888dc8d1df50c47c0d75822eebd3eb019.hip | 144 + ...66feebc9a0dcc508ce002c255154622875e524.hip | 84 + ...cd68acfca68d1acac94f493e25be0ef20f209f.hip | 144 + ...2a198f23c409b715761b702d7b0e6e5992701f.hip | 144 + ...35773419a9b3631698a3d375d829af55f7731e.hip | 84 + ...88f0f7363804cf5403adef70828ab32d09a02a.hip | 144 + ...966fa1ff013e477b1706928de6cb7f8587c154.hip | 144 + ...9d9baa269dfbb30b714389d1733be51cc419b7.hip | 79 + ...e48d7edfe9513f24ad9fae68cac3aa940b17dd.hip | 84 + ...0f47a44400de385ddbeb99475b717c5646fb41.hip | 144 + ...1a3b7d4fdfed64e64f7a95dbc64eff541092d6.hip | 79 + ...3b86fe4e153e0bfa8d1e75f3641fe32b0c5149.hip | 144 + ...6075c3a5fcfe63ba12e854bb1fed6873f014ab.hip | 71 + ...6edb824cecf459a8ec51b8dc74b1e06369aceb.hip | 144 + ...c1a31a1d8556cbe0b6ea76faacc78855108539.hip | 144 + ...cc934ba7baab1a2eb062df1e4ee5066e9ffbc3.hip | 144 + ...d85ad2c9d197f501267fe0804e6985802fbd18.hip | 144 + ...762543d3380185e304f84749a70db1b8d3dd8c.hip | 84 + ...8fd64c2f2b27577109a984e6ab82f5f0fcb296.hip | 144 + ...b629c37cf94134693ce455b8c88b72a39df7fe.hip | 144 + ...bf6805a489739abb77c13173d57723e9304afa.hip | 144 + ...c9f955f227430c6224ebc347649386be7f01eb.hip | 144 + ...deafd2f36cee29109fb824e0135407453adcfe.hip | 144 + ...015c5d50481547aa5754d042d9d7040cf1c7ff.hip | 144 + ...07a1b0d5a8f94e0a0f4032f401d20b4b643523.hip | 144 + ...34e691714f0b99773c2ac515ed82de0f387065.hip | 144 + 
...4b7e452a4db74189334697e3a240ad68085f0e.hip | 144 + ...89d0e4442cd8304081892ddc75043e68a6398c.hip | 84 + ...65193d97d43237c22c04478ca5833011d8dc8b.hip | 144 + ...77abef05ff37ec27705eda51896e2aa3a04966.hip | 144 + ...d9a2396ceccdadab24602f30e9070901a76dc7.hip | 144 + ...02730dea6987e2c038446c448aa08bdcc23113.hip | 84 + ...14c6b4bc75d95a150104a17972abae77cb47ed.hip | 84 + ...2e3053f30f780f346fa6b7a836ad2554cb85df.hip | 84 + ...6757fb17f5e94a6ba1fb14540a68c36d571159.hip | 144 + ...78ec9e09d3b78dca6b5bf0be1538657f02f319.hip | 144 + ...935fbda313d3518f142f43d46f56c600f69286.hip | 144 + ...b2bb9f8466de1ad5210e4c39ee7b8ecacdffa9.hip | 144 + ...b65fc519ea7cfcd19f7eddbc3acad6842ff558.hip | 84 + ...c5079636a4a31a849ce8a5af89d50330a74628.hip | 79 + ...ccd5f7ddc894b2717112cbfc766804e02b7bd1.hip | 84 + ...18fb4e529104fc90069c8779ce5463460bd516.hip | 144 + ...38053e01268a4c5883620fc6a9901951e2e01a.hip | 144 + ...39a1e84faa98477b05df71d363b9ff0f9b2760.hip | 144 + ...8a9e05debd456a9975953f7b0d510e7a0f6978.hip | 144 + ...973d75297bd2c3432a7c88e8a9ee1c9ae693bf.hip | 84 + ...b53fb8d81148ff384d31a703bb4c2e7a5a33af.hip | 79 + ...e0ec1db1ea308e226f675e68e29b839e41b252.hip | 144 + ...e6b10e73733716e71ebf5a53703fb935fc5e02.hip | 79 + ...153f9a9b0b7c54ddf2debbe297efcffbb4fcfa.hip | 71 + ...3a776ae4ba68c23acab1a5a6381684051738ab.hip | 144 + ...5c757c67aa23cb88e1aced6fcf36b7b28391db.hip | 79 + ...5d492ac3a6ab75648056bcf26250a4aa929cfd.hip | 144 + ...6879f8ff4796f48ad87ff8003f4f6e6adca9a0.hip | 144 + ...ae1294b6dea5c8b93c2b814fa7460c4047105b.hip | 144 + ...b2eb64b66d46359fab44333c2c484f4c9dd5de.hip | 144 + ...c0a99e949baa5f3a7ee2d6e84427982f82f76d.hip | 144 + ...d37e7ee96c392fa24c02a9143438a3a7d05741.hip | 84 + ...de729aa50c10d8101ef504138c3769e3286753.hip | 84 + ...3c604d1b8260958becd1c7c209745ff9151715.hip | 84 + ...9bcea4393593313d18a4aa6dcb44cd75bc828d.hip | 144 + ...a9427f34bbf5ddb28a39161acc36806e68f2d0.hip | 79 + ...d8fe5f4f8641998b8b805a20b2ca92d019ee59.hip | 144 + ...d9b65558398c0c10127b560807578ef117d7ed.hip | 144 + ...07e8d1089557dfcc95a05160be5092e9119a53.hip | 144 + ...5e3908479965856843317c8b0c42a6961dfd23.hip | 84 + ...86d5f8d5591f3e0f1cdfad19c38c420fd93023.hip | 144 + ...b04e6d5527ba0b8089ba8bdd264e2d5759338b.hip | 144 + ...b53fa68641f45baabf40b7cfb8b35a9a1b9c7f.hip | 144 + ...077e68dbc1bed2dd20a5f4dd35e0cad6330ee4.hip | 144 + ...591185b1c5f521023e250a26f742984255b241.hip | 84 + ...62567e9ea16771d8445464c38f5a2931cb355a.hip | 79 + ...6a6d4cc262ea838dbb83ee747112f95fa297bc.hip | 144 + ...b6cdc59bf216f7045f0cf5f221bb91ec415cd2.hip | 84 + ...c353f963c52624cf79e82cc2b2c02eed94b677.hip | 144 + ...c5952f46f4f2bf06257b00661774eeed48a323.hip | 71 + ...278488b2cca114adca5e4614d86f92447f937a.hip | 84 + ...b241b947a0adfc8e50c5d71765c14af24593ae.hip | 84 + ...b9abf5b09e63cbe76390bb46ff7cbefb3141f0.hip | 84 + ...171210efd217c07d357fcf42e5372ad7e9abab.hip | 144 + ...3deb1382003ac010d9bc1c59d1878d3ec7a727.hip | 144 + ...51d24ab5f24e003ed6751ae8ae5b327892b15a.hip | 84 + ...7ec8d547ee9713aa3b5b667f22cdcaa8f62b2d.hip | 144 + ...7fc24902b1ebd8f2bf8088b0ecf6de8be8362d.hip | 84 + ...9f63a538940e5ace02ae5b5ddc01f730adac4d.hip | 144 + ...a613eaa8471ad7da66d2f8f2b8e07f6e02b467.hip | 144 + ...d7dec90b3c62bf3a30bd75d3c6869529a06b01.hip | 79 + ...e60111633db08f765b3c7cd5cd768cbd030255.hip | 144 + ...37ba962e0288e2840eb0925d016b5a7e3b3164.hip | 144 + ...6bdf67720e938d538a867548ac3579b8238169.hip | 84 + ...e81dbc4cb208ef6e684c76ba1eb451d37fe10c.hip | 144 + ...1a43f2210a8d1e5623411c95c33424cee5e747.hip | 144 + 
...239db5a67c23a383590a651f0d8a0be43a13c7.hip | 144 + ...8e709eec7aef1fa681053c6d2969a5ff18c45c.hip | 144 + ...974931e65d6b16b7c868d462b95dcae20b7513.hip | 144 + ...b0e96b759e18cf703cfab0cda1385726f6e0a1.hip | 144 + ...e408cf9456ff977aa7d12345e9b2f1e60639f1.hip | 144 + ...2ebb4a86e7ed0001de9c5e607b66fe8877409f.hip | 144 + ...40f0acf1885096efb840ec5600ec421c4db331.hip | 84 + ...5421703cbfa63a58ec02701e245d479a1fbfc1.hip | 144 + ...7cc2aa1ffd38298b52764a93cd1271b4d92f8d.hip | 144 + ...aa0cb33c71cb8ca7b83dd0e7a6c7b01f6b50a9.hip | 71 + ...b9e7d9af47cdf79f15f674f8976c05f08b0ce8.hip | 144 + ...c6a7b25710f0626c3af534111b161e1459d2e1.hip | 84 + ...1468c62c878295443981662e037ec5213cf7a3.hip | 144 + ...20134822739be6fa0bb3d98e9dec79f025324a.hip | 144 + ...209426a8e6bfeef7d8ae7b16db791888142298.hip | 144 + ...28af9e5e3c25800dde938e991aaab4fc1d64aa.hip | 84 + ...53c9c32518b895daaa3521827f37af78836fb8.hip | 84 + ...69b38b26c30bc770f74c856e47eb498f5818e7.hip | 144 + ...cad48d9bc80d58705ea60eb2dda4baad68cedb.hip | 144 + ...246d1013d954a9316f4432c986d3be9459c548.hip | 84 + ...2f1f1b679cabab04218037ef370d2c7e1fe332.hip | 144 + ...5c41ddb04ec7f80235bb3db19198dd6b699713.hip | 84 + ...8c74becc24a93427d9c0838784e9b6caad6e81.hip | 144 + ...ecc90ad7b86791a9e6f73a582aeff30f393804.hip | 144 + ...1596e8c608a795ff971aea8e199db9e72b65d7.hip | 144 + ...4bd5b92ce6bba640b8ec6b4e53fe35902c5572.hip | 84 + ...4d42e820adc1a26a428d59df7ffdd7f8580176.hip | 84 + ...4f26e45d5cf567d29fbe375fbf8abdec39186f.hip | 144 + ...5b87c435bc5d7d85d738f3fdf68947d79f5a77.hip | 144 + ...80e1639680ac1e5830a21f921bfe2cf364ef42.hip | 84 + ...da112b1e07c44fc8a7f19368da203f6935049c.hip | 144 + ...0316cfe49323638f71ba688dd8ff9b2266b335.hip | 84 + ...193ea266f3718398bc5622f8bc7042c3527a42.hip | 84 + ...4fdb8294257d951dcc9c4fa7ecf1192568b91b.hip | 84 + ...6aaa63ed42a578b953ebd614318d44cf44e8a3.hip | 144 + ...95bec57c3b2e6e169134dd8d20b287d7405134.hip | 144 + ...bf7ef503bb026258b3ec3d82d3ef1443046964.hip | 144 + ...d0166931e4406873d8f552a5d5b61fde2391a3.hip | 144 + ...fd08d56f8a9be1a8dd104cdb1ac58e283b5064.hip | 144 + ...ff73f82aee3184849d04c2364eaa45c6d0de9c.hip | 144 + ...2cf0e5fe479690883507028748b0cd3dc83cbb.hip | 144 + ...658c32d562f9d60c5ca1262a2e0df2375063bb.hip | 79 + ...8f8b681a405bfeba5aadaef40f32367ec5cd2b.hip | 79 + ...900c0a5c0d03dc17d7a907ab40652d9920e756.hip | 144 + ...a6438394dd3427f29aa0bbe58ad1f797c3c38d.hip | 144 + ...b87f983a5e84582efa1663f84da76cf60b5f6f.hip | 144 + ...c803838f5644ccc6f04f7c8a6233fed0b6639e.hip | 144 + ...df1cbfbaf67705820f125b474469ad7ebab0c0.hip | 144 + ...0fa4ea674a590d0a817367ad9915a5fce20c51.hip | 144 + ...1f1a11f778d99a00aa5959a3e58a41fcbfb1e3.hip | 144 + ...25b59df454ccf53da6cb201e0aa8d09f52a2ad.hip | 84 + ...7f84892e2a8496169b7406e63b0d4f5aa63aaf.hip | 144 + ...803aadd93e33567aa6b23100ce4fbb6c040dd6.hip | 144 + ...f1797f6b672a55476348571ce17645c8a62869.hip | 144 + ...566441ac3074578cfe45758ba0583c0da0a5ab.hip | 144 + ...72bf80a78885428b2c02e522426470653a7351.hip | 144 + ...82399cd6412fed6a1141296a7e4d42078f7b29.hip | 144 + ...856ca950bcf173571766c3f04de4163be0402e.hip | 144 + ...9548d6cced86c21c09c6475237a0cb926df0ed.hip | 144 + ...9878f4ca8cfe6b8d8748766f66a1ef8eab20ad.hip | 144 + ...f102a388ffb05c690a20a29cfe0b35a35eed61.hip | 144 + ...035f4bfd8f2f427720a07e3c311bccc1dba683.hip | 144 + ...1f96ce4dcc7f789a8ace73c230c203b05ff6dc.hip | 144 + ...27911254904ce4341e4ff5f8bafc430b8cfbbf.hip | 84 + ...31289837f915e2aec1bd01eef1b3c1b099864d.hip | 79 + ...9def2b4edf6d18f6ef1d6b141f9e0435441f6a.hip | 84 + 
...aa9c39b06e55bf4bc9f9a2a0fb075c9d4e69ce.hip | 144 + ...cf08242b3fb1c643d4149bec985b667b9d28fa.hip | 144 + ...51da732f397624717160f89271514bc334b59b.hip | 84 + ...61d8693f82d22e2c5b1abbcbae5f30f4433e5e.hip | 144 + ...7790f260630f312b84888dcbdf849ce130ae59.hip | 144 + ...7991cb7787a29d3ce4711b4ce04c5fb6a14ca9.hip | 84 + ...0410c26d7649e21e2ae5e32e7af89d84d2ea70.hip | 144 + ...2e9a82c879051d6fe3c42108f8a574187704af.hip | 144 + ...3bc23b8a4f1e0fc5c5756c4e1c835bf59dea09.hip | 144 + ...3bf815b520a9d9e17b43bf9d7fb870751b6225.hip | 144 + ...74b12e83e214c30995a25631d37df1478927af.hip | 71 + ...824fb32933b27501ae8a7f43f460a2dda6a814.hip | 144 + ...8a6b193fec3203eaa75819f6b51aa45a48f212.hip | 84 + ...c58761c927b222112cb5cb6c9acb5d3c915785.hip | 144 + ...16fa84278b489af253b52839786f94aeeac36f.hip | 144 + ...62a97675719c2e8e9bb97361b92ff1c7b9d2ef.hip | 144 + ...85f869a92f0482605e52019828244b12e12b44.hip | 144 + ...bdc143c29d5ca50ab1e96a814bda6d05b0d5d2.hip | 144 + ...c5a0f98b94530befd634891e42c424bb86f0e1.hip | 144 + ...c99c3c82b77946f6844699d2333cd532a78a26.hip | 84 + ...f56e45b2240515e97fc1bfd552eb03b6de5094.hip | 144 + ...f686067fa433cea5e95dd523846dc881eff635.hip | 144 + ...2fbb135d59028afcf867c2cf08edc323565528.hip | 144 + ...4c15452f9155c5966990f09432e5eb7e28e785.hip | 144 + ...4c5f8fecfbbe16e6648becb3b5ca89fa3d8a94.hip | 144 + ...5bb49928ce5515d7b297d5eadd4ec70a22d60b.hip | 144 + ...79e1f9231692d736dbada062ed6821f34927bf.hip | 144 + ...9477a613665cebcad781389ba7c5a36f51efe2.hip | 144 + ...a36678d5047ded97ee7a7ba9feb9569afdb6ea.hip | 144 + ...a47fa8d9b5375bc408af68b67345ab9dba2eb8.hip | 84 + ...ea85b766bf0c918ee0baf24dffc6a5563d5105.hip | 144 + ...eec221cd63adaedceec39db41ea942f99f5133.hip | 144 + ...030b61ae20c4b7d9b2d10930a17e01e9e93328.hip | 144 + ...1790325b59bd44b0a5f6cf9723a25fd845cba7.hip | 144 + ...1eb85a00017efdc610e4259d2abe935b85304f.hip | 144 + ...5841a729099340d608e31023acbeaeade3e886.hip | 144 + ...5ebf0f2200f37ccc0849e0c3745f6e2f00111d.hip | 144 + ...7b0916744b593435d8e1e7b6d874d760cd5e3b.hip | 144 + ...86c13e933cba40553ffba31d53aad27415ce4b.hip | 144 + ...b0b08e29b2e1bf181fceceb9dc416e54f52b00.hip | 144 + ...b6ef39c3db49f26f736d6c9221dd825409ec4e.hip | 84 + ...be827108d252b2f5847fa8e132c9c3e56a90a0.hip | 144 + ...cabea88b8e290688c1b360875d228e6fdf1624.hip | 144 + ...10a3b937e9659716925e39a01d794914b08e26.hip | 84 + ...19d7614f2ed5da21a52ed172ef62cc07c9c01a.hip | 79 + ...26e43ca652e6f58ff48c356165aa4349833b55.hip | 144 + ...345632e0cae0d549ba79626a08b1885711deb6.hip | 84 + ...3558b4c7a667dbc365c4c2ceda646975408f51.hip | 144 + ...614df484b263deae3b3c20adb0ce7b62eaa651.hip | 144 + ...9cd1305633b62b68fb8474ce021f639f8492e7.hip | 144 + ...e12cd366d6850ce26afce98e5076b695b4875b.hip | 144 + ...245e9ea974adce2b9807d33b9ba12d916eaffb.hip | 84 + ...72cdd69944d2d765478d4aed13066a02b76f6d.hip | 144 + ...8b8c3525fe86a20a2d6c69585f3e36c16caabd.hip | 144 + ...97b7adcd67ed9bda8831d1f3f1ca7590c6d251.hip | 144 + ...9d98dbec5096a89b116f85675af772f023014a.hip | 144 + ...b5e77111fe1e20bafdb83a925b5faeeb6214af.hip | 144 + ...cd7501265b4c4dcf015485e63e2324304f70d3.hip | 144 + ...cffa403b3631b1957e1a9a06f18fdb3b4eee5f.hip | 71 + ...453e3bdc9752cb7b81f7cc3056325a8b9a8ad4.hip | 144 + ...6862dbdbb20bc63a650e1f93e9ac169bb702b2.hip | 144 + ...b5b7349a671b182d73c8016590f26fe06a4cba.hip | 144 + ...b8adef0cef91a86f36872407fea35df90e8f2b.hip | 144 + ...c6056d9fe125a4dbe08c1d86354e51f7daadd5.hip | 79 + ...d868d49abdb769ab82c21508d655daf54b8a99.hip | 144 + ...f7aa57cca501f221077124359a589b3a6f9d0a.hip | 144 + 
 ...fbfcac254e33926131a71905e93f9cc0aef89e.hip | 144 +
 .../hip/flash_attn/ck/fmha_fwd.hpp | 773 +
 .../transformers/hip/flash_attn/ck/mask.hpp | 157 +
 .../hip/flash_attn/ck/mha_bwd_ck.hip | 407 +
 .../hip/flash_attn/ck/mha_fwd_ck.hip | 360 +
 .../hip/flash_attn/ck/mha_varlen_bwd_ck.hip | 436 +
 .../hip/flash_attn/ck/mha_varlen_fwd_ck.hip | 364 +
 .../ck/rename_ck_autogen_files.output.txt | 1810 ++
 .../flash_attn/ck/rename_ck_autogen_files.sh | 11 +
 .../transformers/hip/flash_attn/ck/rotary.hpp | 84 +
 .../transformers/hip/flash_attn/flash_api.h | 503 +
 .../hip/flash_attn/flash_common_hip.hpp | 53 +
 caffe2/CMakeLists.txt | 3 +
 cmake/Summary.cmake | 1 +
 docs/source/backends.rst | 2 +
 test/test_transformers.py | 4 +
 tools/amd_build/build_amd.py | 1 -
 torch/_C/__init__.pyi.in | 8 +
 torch/_dynamo/trace_rules.py | 3 +
 torch/backends/cuda/__init__.py | 49 +
 torch/csrc/Module.cpp | 14 +
 1840 files changed, 249657 insertions(+), 38 deletions(-)
 create mode 100644 aten/src/ATen/ROCmFABackend.h
 rename aten/src/ATen/native/transformers/hip/flash_attn/{flash_api.hip => aot/mha_all_aot.hip} (96%)
 create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/bias.hpp
 create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd.hpp
 create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_00042c36bc588e60a7c8a9ba297a8a25d8ac0660.hip
 create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0029076f83a3dc695a167beda6fe19230a2b114b.hip
 create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_006c417a52a1bd7c55e45d111483d26f4480caeb.hip
 create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_008f2429c678d13386a06e8d8b15c4b480940ff3.hip
 create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_00a2adbe938d458d51ca5fc4020667a215b672a4.hip
 create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_012c0f480917c329f4c3c6c666cf32af2d82b294.hip
 create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_014c209d5cfc6b965bfd78c64bf132c0154e32be.hip
 create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0153ec18d3ded0f8bdc6459ea5757ebd94d9faf2.hip
 create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01ac1a2ecf9a487809e46faa92e267df2d47de91.hip
 create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01ca79005067e20e4eed5a72ff9187cde702cd1c.hip
 create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01cb354dddef6e99e4ac843f2adafcddfc58d520.hip
 create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01d12033d59ce2799a2a024e5d9232325ccf1320.hip
 create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01d3b034a2d8d0b83c0aefa4faac6c3f28ce737f.hip
 create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01e2428c5447aa9a78f79f73f31cf685c586872d.hip
 create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01e8aedb7b7d77f44a46b2e9b7a826f245aaf4a7.hip
 create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01e8f0df0c54ce619e5b66441b3c96a5e18b05d6.hip
 create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01ee0083f6df962c4a754cd3295b1a436c590a0e.hip
 create mode 100644 
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01f74764c3c3284fdd1b67d0ea781c2261ed0de6.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0225857454eaab2eb664aef7a0849ce12c32fdf9.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0237c76137df14fb808ade8bd6837045f2aaa5c9.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0271bd8b7c270e1593871b638288a4923342c446.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_02d88a03cd3966dd0cff550065f58c3ffecfff6c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_02ff94e3c787a7b06ffc90c25777fa74f225e32c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_030a759dcc92028b4c6f317fc230b98cb929e806.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_031b12f9fd94e01aaff2c0da4f35f346822087e4.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_036887daf6cc092e7422a17882488e59cecfb643.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_037c6c80fcec3eb8b0bef50ad6af6d27bf5447f5.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0392491c5a6dfc742c2be483419a40f6a7a7ea56.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_03a71615a088e972c998f9c7cb44566c268c5124.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_03ff035717140f7385282419598cb4fb2881ce8e.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_041a0718891596ddac1fb0088637029233ccbe60.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_042a156e9eb935555ab14a84461959b466c2fb5b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_04641230fe9a50a221047f7a1df8a370f72805b9.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_04c363e11d202c6d2f4bb753661c5a2043edc0ad.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_04caeecbc01667ec6f5599358a0a20423aa9a00b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_04f39b453505f68a5091f68b1c3de48369d1e7ea.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_04ffca078cfab8bc6c4ccd1cc8994a1bb4a88ea7.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0502e718337eab7d47aa65cea7d3c5f641484520.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0513b2f3bd8ad51315aadb7f63737201898adca8.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_053981d9e7af2ebc0f91e61ac5e25cbe68c95bd8.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_054fda16133a0d25077967b05425f9128e1fe1a5.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_05538339c21c92c53d237865d72debaaf2ee5075.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0595316f0dfffda03e5296b959a49ec3f3c48d67.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_05dfe927fd64a564c5fad537fb7c41ee9c94c2c0.hip create 
mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_05e60b3ab7477f9edc8576a8bf43e3a62b8d5ef8.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_05f794c7023cbb7e35f1fd1ae45bd2377bfbc520.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0628931bf5cc1daa6e106cf60bb21fa1aac6b1df.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_062c8c3c1cf6c33af4574099e9b6ac54a55ad776.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0682150e93f547e00f13cd8984779bf49b91e50c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_069c663be0267c009be4814e9e4e7c13ec999411.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_06ae52ef937cc27c544e32025ea0dadb7fad982d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_06b74acd9abfbd1c4ec2f4c718eeb92a0bca7bab.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_06ba94794a14f0f0022af6f5f3c16e1e16959d4c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_071751b1012b90f7b57f8591cd06ae1fd27d9cd3.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0766e7aa4b263a811408b285213e47176ee2bdaf.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_076b3beb57b30afb30636f948e3989b346b38d20.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0789852b0cd3cc030c78b28f2fd5b6b0546382a4.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_078b96ad691a85eebd18586db0b62b8911016d9c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_07c3fc96d2bebe546dce6ebf46e5c7a519959599.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_07ff04fcc273e469737512893ea3fb5876ac131d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0801c56831b4c6428200db6318638a2129bb197a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0836d5dfc0f939ab9a4064b403339373caf35b56.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0842c4e3aabdf55405b3ce09ce1899245ddf11ad.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_085722b43cde5f37242edb071f639da7c4a0bd48.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0878b9aa31429d23a93cd953cc6a2fc5f43d0d3a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_089a347aef8a920e3b59d5ffe71fc5bfe002609c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_089de13222caec1483207d4a54249f8da4f9c151.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_091cb49c1958fb4342d79f367ea93cf2b472f785.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_093834d4d3fe76e1745e4482c6b51b550c6f3dfc.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_09513bff5c1da6aadf11d2e8272a422eabff21bc.hip create mode 100644 
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_096863cd93d1b105a617d0daa1d4f37d7fb6b893.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0968cebd81ade762c2f92fffc0153fa7a2b91eb5.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_096e888c52d0f4a5847d7515fcc66208b1ff40d3.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_097b3e1dae9bfb2e89398706508f8e01966fd4ea.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_09d76cca48b71dbcc9bd96734787209fee4c9a74.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_09e50367b62bb09071e28b44235a7c112645a706.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_09ecb6347009f6a5d5530a6acf90f9f40288cbcf.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0a2b116fd5065109aae46ee547e4f49ad0e9d6e1.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0a4e76d89b175e1d9fd2e9fb908d5fce1ebb945d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0a55ed15ef58c941e06dda890aeb530e28eb7bba.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0a672fca51de618e3441cf8764e8e83eb782f2c7.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0a68c2f9a3acdd787b81be455cbc7836c8bfd90c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0a89417a043556970f72eebd48b4f3e7ac15377a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0a92671b6ea99891c0d69b1c793f4d131b9a82ed.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0aafb881e34a3794970a1282af740b3f19c138b1.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0ace6e29e1d3060c3086c08fe27b471e375f9c75.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0ad9d68fcee021437e13ffdf94d78252205f5a31.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0b2647b5982405a48e8c8888552a4b89386ccdd9.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0b2efefea81036641561bed80c75d77651176f74.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0b3153af7bcdba33115a0d31f121fd76be2ffbcc.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0b532fcf26f90c82a792cde7943634f667c1d033.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0b90a0186d8b8004e3f19886c7992c8e04d0e066.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0b9585ba1c10acf67115c5899b3546608541820d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0bb81407c8a2b3cdc5fecf655b3ad64d5d729cc9.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0bc7910aac798f0555e9e505ad7f177c9fbbd92c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0be8cf70c6be969ecfca675782c860b5b75ac089.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0befed50a89d80c22b2c8c3d5ba67d73c3d0190e.hip create 
mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0c32a2d9701e23dd930119c4ee8089042b5b0ac5.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0c3b2ec99fa7b09c7f78dcc3142a661d686044ac.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0c8a0bb89a6f05289c0405df5126fa0cc16252e7.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0c93c65e5942a2f43f2e491547add02777dd2eee.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0c9bd38b8f9009d932ec49204fdea39a52885246.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0caeedaa7d50f1741d618fb6c573529eebb075b1.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0cdef49859c80c6b3ba18eb2fb4c35c72abc1cf2.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0cee6b9427c164d78994150305a47f73954a67c0.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0d0e0147a92061d32608a34e7b47bd534eb787fa.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0d13a4c8d169877da6408584dc1f20a6f7c5e3aa.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0dde401aa76cb5425563cbbdb0362748148da3ca.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0e007c36231ccdae12f102eacca1f74b0711b9c6.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0e0a2370f2a320484d8f9f21e3197425c2dbe9ad.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0e1dbc9c433ce8ec33ace9e62550261d613db582.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0e3f4cd28a4c06cc109f6a0798a77844bcc750b7.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0e661b5f30566d1f159f060c264849c7ae4772f1.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0ebacd06455ab20eba78b389462946716b5819f6.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0ef309b923172f4c0fb38d9b9f5325b33b4877c2.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0ef9b9413697d6f4573c6605bff6f58d027c5016.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0efdaa9266a5a464009297dc59db92504f8bf1a3.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0f0c699d9c3b0ed62097e38ba05e40e815cf474e.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0f588dcb2ef86677ebf84e406eb802e9921d1f1e.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0fbb0bef3b388867e75d7a8a187b8b4b650a42ae.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0fbddf533661642d84bf5a16149692d5a892182a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0fcb7492feb79e27e0bda73e57ef7dab410e2bb6.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0fd4068ea93fcf4df463e3bf3a6898d23b65da7f.hip create mode 100644 
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_103186dbad604763008e0204a1ea90baecef8877.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1037f1bc50c4a65dac09ba56b701256b701c4322.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_10a055e5c3d6a953d470db5dc21449766248058a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_10c24f1f9009e46afa3a59193784cc2575f79056.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_10ceed95b0a0a01f844678717c88e0426fb503fd.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1132b11429034d96d82c82dbfdb69e460ad8a564.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_11e7df31541c3aa919e9825ad7dc4432f9a03c0c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_11ff174ff2175e9ec22ac3a0fa59dd7713b79643.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1211733062ed30b876f1d63bffa642d77e258dd6.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_12207f4b6e7fac27d6c16493a5373f448a2aaae8.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1241814f76107d74ed069ecec99a248676487eee.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_12d5c8a4988efe60ef7943ecd73e18a28a736583.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_12d60c8abecb3bc9b84b0ea7851628ab17d8b0b3.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_131691f01cc7f29affb88152dd48c7a484315dcd.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_131c1fdc4206bb952b2fea675f24e3b09f605eef.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_133c51948cf8584900807998da14d788039f53b9.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_135ea67de101135ed5fe04f5cab1ec1d7b3714bb.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_137fa6780d9e6bde10aec10a875c039fdbbc652e.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1386cd75411e61a8dbbaf2b916e62f4f5f99104f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_13d5f2ec83b3331654e37ea0b44d88cd98abaa37.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_13f747525ad31e76c88774fb2208e470da9c2310.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_14221590b90c48d3cf259fb4e834ccfaf7f3209b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_144f19363ef26efd36f0436cfa9f84f181a8824c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_146eb8c40e3146e06936f3141b2c4d92a578ddec.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_14baaaf1e90a075ab802c6e7d97c4b1605c8bd72.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_14c4ebd1792c781d219bd21b691b575f64635730.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_14d11aad7b666f500f68b264a2fcca6dfc5f1a05.hip create 
mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_14d4630876785655bd4950566e81ae0b645c0d3c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_14f77aeeafe4b28f314fde5ebccfd2a554872781.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_14fea611f3c253aebf726af3e5fdb7e63e18e13a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_151a4425b411596c46c7032f6b83d3152a0e0cd4.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_153e897098539c3466da9d7a37234daf16476277.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1552dc38d26f6badb7a9bcb5ce9124d54cc45ed3.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_155bafb551768855c8c01faa63e44764ebe6c110.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_155c3549d067464d186a99b8205317cc000d4898.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1573e3d855d28c54af612ab950b081302891d56d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_157768cd725813f8111d265cfdfea7f42034e5e9.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_157b89d8d625b8244b5cceaa4d3e5fc5a09c8989.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_158d5ce564c3ae1eefb54e3d41dde2604560ef4a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_159ee1f1b44d1a8fbaead65d8449413bb616d15e.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_15b255dde1a9d915e582ee2a83de7d83190c6a24.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_15cf7068183421b141ed5d6e7fe902d06b6492a1.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_15dc02ea7e0908cf0bd48034f5a49debfaa36219.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_15e8e1ab8c63db96843054bb7a98d708ae6a9c44.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_15fe3e8f4add16a088fe44458353fa7c0c4f9658.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_16047b5544acef40e39932672cac6f562e200948.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1621507cf219fe608715d4e5bb6e5764022e2d61.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_162b0dfbe3f615b1d164290799b2457437a0044b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_164a947a6c2ba83a5b1cb7074aee0bdac6c9c64e.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_165dfb45658df8f1ae8dc0738ac9614740f2576c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_167f5328b035ed59a6f05dfee31edd704c4b07ee.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1687ddf65ce4ed2997583e20fee9f201e86633b3.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_16f94f5c65c37624f5458c165daf83517d9e3c81.hip create mode 100644 
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_173c44dd85077e6b12dd06fdcf6b11ba349e1866.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_17b9b96edda151072215502cc2b606bf1f6f0b03.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1847fef2c06ea581b0ab31af1cb0556c572696ad.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_187963e1969301abfa61d06afc97faea2bb4efb1.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1886d4bf54b3a4a9e093360998b2059b3c03d072.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_188a70d526394e254274df95de0727850820326c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1899e28aff2fb168cdc3af7132dd7fd09c2e1ced.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_18a4d71b31c451a50df7996e3db864bc3c3882ed.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_18b92b4e249195ac3e0c74d246585a4c9e0992fd.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_18ed7195a9443c84956c3f32839cb3ab9056bdfc.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1914250fce818584291c69a5f058a58cfbd83df9.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_193699a5daa14ca2def07489e0b563149bc403f8.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_19af6a7f9e5020e8d0f0ca0f6258001f6ce592c1.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_19cd9f7b08cec83736605af63d9fcaf463a1aea4.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_19df4e13108e043361e9528b71df56f04f696a0c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a11dd5ebb989503a1c182684e7f247e2f8cd9c2.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a236be9da05a07d11cd28034d90cdf89941a172.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a5e18f6333ed2cce509f07cb8bd5868951d66a0.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a6785392af35e27d6697b584cb6f17a766d3fee.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a6bc2762b95d550485aa720edaf71138d94cd07.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a8da3e6ab050262b659c801ccf9a14787d7f176.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a96f0ac76f117e66eba97cb990c2350561ec2ab.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a98bcbe900f8c141136d18c114b02fffbe8bca1.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a99b2625adffa8215276bb88fc65bae944b846b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1acf2f892742b1d236d2b31a8185c6869126adad.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1b3e7c8969027d3316875f33dc50fe022e05ce37.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1be43f8b629e7039f57b95866d5777273377470d.hip create 
mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1be746990a2032f0363ad9f9112cc994983f4706.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1bf767e7104cfc8322f26df35907fbf04b8948f3.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1c1b0f85e085dd0769c566fb16aafe5ab5952714.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1c2a2d78176e3f0a78e3ad78217e75a4430c0de5.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1c65ba6dba01da9caa84ba89453b61d81376763f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1ca3f45d0be2d1119cccd0af042a3e8adeda2ed7.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1cbf88db44aa5f884438288a325270d29c7a04b6.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1cc459e57bfed5ec7f40ea4a4dd9f72f3ad7a709.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1d02609fb803ea2697e2c2cef35e6f923d2578cf.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1d0b822743e0205f60521d38d7c64f589fdf0f58.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1d21263e16dafe79b9fe2f998847296e575c14e7.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1d3ef3d5ded0dfe2a0bafb52ea8f841658db35fd.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1d498e418ebbf33bed58b4074d1edf3d9bdd07c5.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1da23de9604b5d98fe02529075bad995954c12ca.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1db03461737f1e359f389a8d297476f9b60faabd.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1dc6e599144a093203fd7f92ac6d3c2cd7180d49.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1de2f97d49f015b9af0b186801e939c6f357a0c4.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1df893ee660d37fba7eaca452ae65b3e45a73087.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1e22f2d99804198c61251b4629a3f18ed3dcd42e.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1e33ce1fa113b221e5303b4093c2c4e748ce8298.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1e42736d4f677a59a172bd6f162616a437696351.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1e7d7888480b83c78833214b32e10f37a6e20301.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1e9130607a2d24cb0662a47e9cf12c6602143838.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1e943fcc2e64c618fc1415b3f1a0db4d70aa8494.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1edaf9d4270d2ac61c299320e06ba73f44730364.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1f0cad6ad5b172e51c569e84cd54a19b4eb0ed05.hip create mode 100644 
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1f13a6d0f8c798c0c4ba4ad202d081899fe081ab.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1f6bc5faf18be193212217788d476ce6fd384bfb.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1f7faa0b33a9aada86f032174afd40d18efa7715.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1f81f8cce0d77dec9f977b9eeb0778b70a13fa75.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1fcdcb750f382fc7828a9886585f50efbe5be735.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1fd9fa7c2e13d0bad5fddb2b5a316bbc09d397ea.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1fda1c96568eab89a8f6498f8bb23c1223cdc7b0.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2005aca3520b171bb82d10ad70fef44f28c19776.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_204a573ce6b7d2f90aede543939315561cc43177.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_20588bcac681a5d69f252d7523a3681a0c6b6181.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2081430c92864c29bb9f409e7c27caee1de00749.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_20d5c3c86398f6ce55abc90db3e362dbf9f457f2.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_20f7ea0aabd069362ba4bbd66623cea5b6e1a6bd.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_210ef512b7862837f54acbc3b21e135a192647a3.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2122c973581930ab7a4ebc90b3bf1cdaa229a87f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_21411df58165946bf02942b597d94de7dd856987.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_216806a4598c885e517e664fc8280c59ec3cbf11.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2173b7c710d418f44dc2b41bec5905024334eae5.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2177d95cdf45f6fec95d1812f2ef183a75259e38.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_21828c7d3f5574690f12f841c27f025206e6165b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2184fba2eec5899bb40d49d4508196e6be1ec1b1.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_21e235e31d6955393ac8e825bd69ead70687b7c8.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_21f860d42fdc2cc6bd743d53ba546e332c22fedf.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_22105635385fbfb5d2f330df83ba6747bcb27f6d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_224f9af5e5ca519b21b71a54acb49f50b4999c47.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_22511de2592b6e350737e44865e1fed6496e3f32.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_22632f996eb63fbe4bc5748c5897b775087446a0.hip create 
mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_226662cf1c9900a4334d2cadcc5f5ac3ad355f05.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2273457ac3be01cc1595a015a5f598f8290c77e4.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_22a07ecf1a59f72ec6bef3e970d7f33cf54c5f44.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_22c142d869ef940ca876c93033ad53b576ed34f2.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_23047ea90076e3b0a3eb0586d49b9ee74ca6d279.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_230861e81e5acc523fa680534eed757b7b4a4e1d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_232f61bf31dbb5de5d7039d5ff2338068a759b68.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_233132e712eba8972ba444c604f89e01c5b84cc0.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_235bf652702c2976551778b9159e09188575c63c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_236b3eef02b904304348b9d35f715b639d63218f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_238e4c1ca112afec494fbe47a85b553302c43395.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_23914c00690ac5c4f89cdbbaf00732ba66c5c0ef.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_23c9b46da8774462de8c24e14b12df3ed596eb57.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_242013527a0266ad479715ee3e6ae01c45de29d0.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_24410fd9a4150c33186a2a365d06d8f6ea621c20.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_245d90000b55ab8b6055b1934880fc6c4870b34b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_24643917fc970c043d1c80d8d4b17ec92deeb8a1.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_249668a3212cd00edaae871758be30a5a1fea589.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_249e6b93baae25dff97a0bc9145a8d328ed3f317.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2543da478310245e19e6c6a0d9ed7ad99540b3bc.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_256ef175029a43e64164176d4eb212baf9d27bb9.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_258d747083272ea657604ac84867ecea17bd65da.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_25938733446b6c0dcd159719f08d04a9aa467967.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_25b3225da1e1842f83592971a1f62a0fe30aa9d3.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2660282ad39ef034fecbdb74acedfb48620b7dfd.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_26835ba70606c769e56d19dbfe74061361aa855e.hip create mode 100644 
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2695783ae8f0034692efd6563f789ef03fd0f4f3.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_26d77b228420a3ead919474ec9c6fb2800f86890.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_26ea90eb5a527434c1740933a1d2dd863eccf14c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_26f90358e522d7bb7c76c3a2c6010f0f38788bb6.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2703018e71d57d3266fc35e2e18a78faa3dd52ce.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_278639d44a4a8372a627a7c31e9527c8faa26f97.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_27c2000d32c230a57a6712f27bc0fba02722f5fd.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_280bfced8745fbd9266207463fb41476dc23afff.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_281d897ad17d7f6db2741b396e6b85a9b8f35286.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_285e61dad8f63fb973cb2eb899c959e400622652.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_288458c5a0720ef152848713119ebce6d76db6d6.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_289071756e7d0582eb61ce6483fa3c988d2e10b5.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_28e4d2c757e4b8c366a2c320360e21ff0ef671a8.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_28f1ef32c4384ec26f3dc5e3af6a74fc8cebae92.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_28f2e2b108a53308a0cb6c123c8d318cbc2eadb4.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_28f7634d29bef11fd466b452a46b0612f38c949b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_290c484c2a366258941ee0051e139ea716a9de2f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_291a8bdf9d63b112e7fe5fa7e8835a6789cb8ecf.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_292454f2d82184ab0491ea0675750c6ec55d659c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_292b4f995d622826af5d1f2bffa7ba68467c841a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_295a523f815eb822d66162d4feb75fe0bc50b648.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_296c5836ba118969c4ba89ed62a98dffe3105738.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2995d39cd62f20622a31f11a292ed175abb5fdf9.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_29bffc159b0bb826ba489ae763dae141bfe8e802.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_29c9e5384809b21f39e78bb2e43af345a9a21d19.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_29fe68ba10b3480dddc9866c51ca8b5efe962cc3.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2a3a980a26682d879c3a3425f3ba5be3f5761adf.hip create 
mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2a45129fc4995abcb8f880692f11c6186fc01641.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2a833fc01e88bd8e256ef64ae8251dd0ed10720b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2a97c457144cb63a9c6c3d6be613b47bd0df9928.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2ad492377add5c8f6d0d2dbf9ee9e4338bbd9f1f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2ae344010d49f7f9a6caab2cb84be7f87d2d96bf.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2af6c5be53732eb1939a2f93232af7dc011dec1a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b0bcb241e5a1be1d35366461408d06e095a26ef.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b3326e055da32cc979892a2fbd0f7b003cb9f98.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b3af90387f1d227119c5dcd4b71362940bbce52.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b4050988e5790a28dbe10b4c20e14f10f6cf85c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b49a9b0801a06dd89c7f7182d7590b515df1592.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b50073f6dfeb7ea77d5dce288a1d2f08f8f6362.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b5317b6cde327a842170ebff20c2b03d81379ff.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b8169ce4b4b9a17ac96fbb232e6a93f22071ab4.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b823c3b99e7c8d1cdc39a5dbc7365a383bf9ccb.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2ba934408c75da5479cc41f96b98ea7d333635ea.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2bb6da1095bd8669c0e48b5cd808cf0dcefa2674.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2c0bda0feaade2b554d648d72f219ac9c389bf09.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2c2e75e6f659a500dd3cf2cfd65118f111342119.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2c77bd7e89ed832cc31b2995566a49bec6e4cb52.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2c7aede7762a524a7a424cc4dc46e43fdedf73a2.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2c808da5c2514806c2953bb77d5692e5d7c97aa3.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2c82e3c4e445e1e02f14435e4ca01a90850139a4.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2c9756060ac0e73dbcfc58a9222a78f0283cd029.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2caba3ab83239e474412fcf89fe0fbef97e51bf1.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2cf351fc2c2da4a8e1760a3affc9a5947c6b3bda.hip create mode 100644 
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2d06f77a4054ca615d96636c0e2eba2a89850142.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2d1f2d1e57095f756ddd11e8e9d4f6f253e3ffa3.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2d23a26e0a59a8323dd97632e610d24624143fbe.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2d43460c011b8d5e01ea98c9b8ddce962de59a96.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2d446754d7000673779d15d3e73039fd3c10a720.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2d7b637e0313cb423b22cd8844cc2997b3ff73e4.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2d9a04b7f41dd6f0db017157a44790f35c626e2d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2d9c659ba43bb907fd4e3e36a50958288bafd1a3.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2da2b905c4ce32234c2af62328adae6b1f9217a8.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2db33b5442d2e0948762b1f2147a321a9d6907be.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2dfac5a83def98340c8786d55a30a98ad68b9eed.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2e30f50071113dc4ab59468d568ac9deb06b0342.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2e43e401abbfb1b6737e4dc822f68421abbc648a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2e8b4260626beeac76c26dbcee3cba1457b30e99.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2ea394a09c8691a534ad2219bedf73724b6dd5ce.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2eba937ff6d0302ab013db7349d4feb914107f1f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2f0247e301a7b076b6ec8a778c3b47e330638963.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2f32f2d658f1f69840fbad511ce8a3851c859d52.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2f55a23a0f24ff7062a4c286944f25d2db3e20a4.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_30024440e780fdf9ec94deccc85216d8bbb5788a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_303b7b04496e4db7c1ba2436485dc7c8a4c88448.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3076a6de0e2612279e0ed64612f7393856bcc9ac.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_30c8e4d5c761fda50e010da779e8e4730051d403.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_30f0200092b0e18d57a9f5e512d565f1c0229436.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3108502fd29d3a24b32177bcea968121ee809115.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3110540b50e95e99a5cccebe47d9d3a83093c2fb.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_311104394c8bef8d4ecff35c1409221e723a5a8a.hip create 
mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_311731442b756308c0a869f21b7b8b103aa613e8.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_31222e158484773d2257f4a31e3dfbdb68336a8e.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3163272d25bc2db2ffaa1fea87648b45ee68d408.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_319df310195191895005b30151da8c1afab6c82f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_31a968898f0bc6366313e41eddb5e3a3ed12dc98.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_31b807c48c472e9b1311a6037cd98e21d6706889.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_31c3760f5978baf9780ce4587ae4c768af0e49d1.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_31c4b866692ba5c3d115482bef4790733863c1fc.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3206cc121ce8955ed59ea3b12b858ee2e0cf82f8.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_320a6196b662a1d3dc7441a9536d825dc356b95d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_321500dd4c41e4d68834814a48a639f5ca36a2fb.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_322a86568f89a5a5a165cfffbae9ca6949f2477e.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_32438250078ba2a47345ec4955dafb4e4de78a25.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_32527660fa7aeb9a951a9f2fc3c53989bd141c48.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_325fbcb9e503e68fafea08abf86a4951f440850f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_32652a27e8605cef59c8341813b68e7513be23c5.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_327e27892bc57f3dec0da24f94f2a483d6c9321b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_328a311bafd1c153525393b252e4170f8aafb370.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_33099fcfc218ffdf69edb4f2f0e46121bea9fafc.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_33746071156e9ad46f403a539dc237e0a44122a7.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_33e7c1e5f41a451c7baff54f7238b220f1bdf8a1.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3400f0af03743dce328486f8fc805dd30bd6da31.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3408103188e27b3bc55dce0c1716c0b4d32d6494.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_342d29c85070f488a14b1915f948e5fd69019c99.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_344932e2655d7b32704be8de9a63bbd8c3369f02.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_345a939a2491166dc520e9a2b9de7e43671e0c2b.hip create mode 100644 
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_345ea796c8d97bfe3b7c9663bf15e2e5e7696235.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_34807a8e90bf1cd839f32fd718afa6469c35a4fa.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_349241529745bf138552f49d9a93db418663ad65.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_34c2db98d8e2e690f499f41cfd5afb831b756f54.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3511c54e6a6f9eec378d8b661121066536195d3a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_351425a006aeeff4d69c8570cb6bf1e1427d2c21.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_354121d3bad1d448bd413718fa096f54faa12e95.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_356f83cb96d0313abcdb24955edd4264df72aed7.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_357f7e626135cc9176a295f3d1f336a7c3852688.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_358399e756ed5026baf3ab78af17489dc07b9532.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_358d28c958c0a831a615a4811d13279b18db09c4.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3642b78913a853a62dbff8b99d9ae3fa458f461d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_366662dccf2f650bcd8123c49006c759cd4c0ef6.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_367e58867c46d96c9bbaa96eaaa9f93595c9e099.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_36a0a960541bd8a2dc6741579de685b7c0a5f6d7.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_377b70f54cb2778b5ce3df936b477f775eea8b3c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_378759ae25465c32960487375828e23c5f1ac869.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_378bf438642e5d863e31145ada2a0688059aa5d9.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_37ad61bf8427a26775969f8a9166fd0bfb7446b4.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_37fe04467e87ec2110f60c7aea0cc9bf2ca07481.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_38010c9bf7341588f071f889b7a0b4dcc4e7a14c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_381b29d9888365bff0f109d897b508eebfd8a61f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3824e97d5ecba46e06d5ec1a9456c810d80227a3.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_38273a2f8e6bbb42ba0b0871b6c95abb34531f33.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_38a5ff72f22e0ad040a281e66b1aca0bf3a2aadb.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_38abcbeaa4d33d3150f2b0238bb62ebbfe960980.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_38b94d76503e13c911781169fbc378517332c42e.hip create 
mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_38bb367362fe2c4849ded728ec5dd00969ce188f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_38e12dad9e3bafe177ed3c27c833825813e18fc3.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_38f8a89468cf9c8606cf12a930db062a83cd0ea0.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3937d9dfb68351de2942e32f35e2ca1ce71edfa8.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_39422621a00ff79b2f5ec0dafb957c77693537b3.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3967a8807c9451b09227c0f685c18aafeb062fd2.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3992d5df4ba2e999caf6889a852db4e1ba078e65.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_39d3071347a0c98f3221104036f477aa13bffa4d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3a1dca5feb864e8981387c2d07e62acef1730aa8.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3a2280997eb6f1d091094fc54cecf42b7c9c3a2d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3a2643099365d0903c799585f41dc1a525ac9f9e.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3a6b9566559ed2b1c85f2bea1c55e72c41dc47bd.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3af86f458fb4dfcceb7db3357fbae0dc15142a15.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3afbb5ac9048a962a60f48886728220ae6c2aeaf.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3b26eafe76cca8e74e819220b6de1f4279d48e43.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3b4ecb47f9ebe8c2784976c3e9bbe4834b475cf1.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3b508b92f7e123b21658f6e17d624ffa87831fee.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3b5b3c218e4a7b459e54080e24c5b730221eac02.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3bb129e6dee6848043dd0e8fa812ae80fec4d014.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3bb3b682eab96e4e173affad75b9d8e73f1dd690.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3be7cea6df8e6dd56194e1172f28943667f1c4ef.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3bed3aaf24c73073c604a3b23bb4b0358b8e3490.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3c1454ffc1418dac641f63671e947d9f550b1f0c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3c38bb80e9880335faaea81985ed5d0e713ecb08.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3c3b7e4b8c1efe59f79a15512716fce2282a79a7.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3c64c33870ebc329921cfa3867d58b1857421f65.hip create mode 100644 
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3cb0cee09d633b6f70febbba63a1e090522cfb4a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3cce3baac1e3ca03af0c3f4ee4d0158ad1031e9f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3ccf0a9d5a5451da5dbf6075ccea45e4a140550a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3cd7a9ca49c1149d46f6b05b0fefc41ecaeb6ea1.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3cf45927b6d931e31e2209685d787efa28eed8ba.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3d1cea88a2277b87d405025ba256272a1720f88d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3d289100991d4c8c362f64c8f6c4ba395c2f3495.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3d3f3eb2f5eb1f3287879604892b1c230df85f1d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3d45624dc6e33c477c73a155500b015b6c010de8.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3d55cb42b0096a8ae338ce100f86e378aa1a04c9.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3da8c31f6d5bcaacfa4a21aed4d1d3caecb48922.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3dba3cd44f78c950fe7ceaa5f0629dfc607b30f1.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3dff884e176ec7cff86d17c6afe1ddaa4dd6007d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3e143d88eaa0d9cfea856b2f3a57d1275a656627.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3e2557f206fd81d82a3b9d59113105040beb891f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3e562e6c3af28b8478020ce3c3bf73c036001c93.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3e61b019e1398a6a3c36143fb84b5ff22c9f4508.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3e839660557dee9d5bcda9b56940ce23236c5f6d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3eb2ea922daabbba131b90713e06d8caf5f30662.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3ecf565a5a1c4a09887c67ac3b9a019dca427ac0.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3f34433b784d1e405ade3378918641372a30bf6b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3f5e01b4f2ca8ea10898c39d6570bd74e85f46ed.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3f7315955f555768f24585a50d75e216c40f062d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3fad30ff0739ab5dede67a96e859f8c474c245f8.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3fcc6893456a559c7d22714116022fc69b372266.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4018b1fcee808b6cccd131418b6ae9e8bf900d8f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4018f690b6322588041bb467beabd8a7bc79a2e0.hip create 
mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_40357c5e9739eae136a7abf92bc38d3ac94753f8.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4052ca6a3ec02f6559e4bbf1edde42ad2d127c26.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_405e7efa263223148318ae96bd1929b382e994e1.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_40aa64439b80ff8dd12498b3e5f6b625da16e285.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_40db688a9189e1c47c300d474df946a248a63303.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4118e3ab290263ed2576feaf22a1944bf2ddcb7a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_415b183c50dd2663dabe3eb8b780913b778c54ab.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4160f6b6d0869740a5a411abd80108f729f810eb.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_417b1cb14b67dc82f614831550f7deb0895bd7e4.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_419461cdb5687ebbb7bf0be136071d70420c1619.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_41b68458076e6cb129d3ec793e95b91430a0c8a1.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_41db3f29d1940e59dadc357c040ea37a6ff208d9.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4217a48a1677bd26cd48e512f1fc8830a8a551b8.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_428ce4e14cf94b284ffa735fe03d923cc74c9fe0.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_429b82a27571ac91e3631cbdb7e0a58155abf962.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_42e2326066c91452335eac05f25a6311376bd9e5.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4306c6c37cf472ad262f53941611b5e60072bdf6.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4347e039c003489dd528faf5d710e687321a3fd7.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4356b3a2ff49f72b91a6b9c215df285f2798ad47.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4377ac04be3a6cbdbfbe57612a469412812fb5b5.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_438e3565f4c720e6c9691b0d33c1392936e2e7ae.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4395d3c96b3f4556b9765fd0a3b5701b2fb10948.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_43e7c78e8f65be35e2753a0ad5123118555c56b2.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_43f2156a04b18bab55af60e9357f28d8a4604e8e.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4409f2a7deb027e864afdfc9975d3ab93c5dcc9a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4432c5214c4d40c54ca2d02f0d4785c6d6902370.hip create mode 100644 
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_44462715ed5f192532760d6f4c66ff9d4e20e254.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_44564dddf8b492d80be54854abb8d1d831e42679.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_445cd8fa559588f4264ce6192f2de3e3065365ea.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_445e28a8a51cd435130ded2abc9fc606e522c713.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4462b192a64efb60d5484798526278ac7a0fb9fa.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4466b6c6b2ec3acb40ac1cda432efa1e4e62d9d9.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_44690e48f30657b0fcfa26fb3b9af3ef76e792e3.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_44c181996532676f2140fd026707135144e9d37b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_44cc95831c347212021c0bab7b43acd7daabce42.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_44d82b58fdc3e5b7a7c20490ce7f5acce4e6ec79.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_451fbbdc2dcf2ec81efce34673ee6c425cc16ca2.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4568af1b2f104664fd05d21ad789aed39ecfa42b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_457eaffbff3c58183a656687010daa2c16cfc26e.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_458d708d13577f2b92e6d5adfe952a87e0cf7be5.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_459c8fb6028991321b09a990c2188d854d940268.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_459ea3713aef9b916e1b38a882a45012930924d3.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_45b9871c220c0065d74bffeed4021d0304a9625c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_45f4363f50af1e7ccd24751d5f5b181bf32c604f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4601680af41c8738089ff377147e0547dcad114d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_461737a13e24009bf1a5a4b780175043a9f2e33e.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4666db0ff7b035e54f2c0e59acedc2131b722a55.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_468a5f057fd5cef2df5f919f5102f47e86901e3b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_474fe2d739eca8c93fdcb2c105d4154cee6ca1c1.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_47548aa042c69bb9c59a8bf706b44028aaa41830.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_47f3ced9b5ddb0dfee8ed5e7df8eca0bbe273047.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_47fe73f04cef91cd2a0682e905483968ff80eadb.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_481415463f0316ebe25ff2fda47c68cc54db3359.hip create 
mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4824e1f8cda50f80988857611da766685da94494.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_48280c91d7cd8712fd533e246a6b0f758834abc9.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_482e34930d11ff493007b1613993e01acc1af78d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_48300e0aeabe337785d4c7b41796ce65df6cc42a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_483eaea4096c8f5bee16a64860432f0634a253d8.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_48435e5dd23e49e19dd313f9891ffec800ce74c2.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_486f6c7c7655c34b7b9973ff357b0813f0a3fd7c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_487724686efd35731e5335efa949486c93ae26e3.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_489e7be0f85656d012a6451b65f6c1d2613b187d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_48ae3af78583258c4b13c11a442022e0e058bb85.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_48d7d145f96aa8958a9208d0c8887742a8c834fd.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_48e9e858abf6f77489f3fadc4ee81edacd26705a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4904c5910a2d0595b39a3f87652a9d1ef4fcbe80.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_490a68220a7b621ae9817d7b77f55de239b0a4f3.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4911bdd71351610d55916d452495e599960d0a41.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_492fbc418e829f89bcb8d93f8afd2869dd8dfccc.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_49d4c005d723cdab9fbc307933c1257d114b539e.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_49f5017cc0f5c8c8dc71492e7765cf729c1f225c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4a06b5b153ea6e8b1e20d9aad9d4633333fd98f5.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4a2e6b05e7e4de2cb23d815f8b2c8adf22131c0c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4a4a00bd6ea27ff20a2903d619e1361b5e27672a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4a5dbf601de5754c03a03a1a42395dc0766fb8ac.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4a9f3da698a6103caf25d785928dd9f814ac27b4.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4ab5d6e8fbfd92e9f7e47bda5cfbb0d4162a6319.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4afd02981f92fbef6277c1985cc479c12bae9239.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b1eaca3c37a82d19f8dc91f06764170069ca3af.hip create mode 100644 
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b2e7f96b095ebfb66ecc7a75752fba2a63e4f37.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b30f472f00bec9da0564ddc40e07112b5f9a117.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b45948f2795293e72530b02669c4f549608ea7f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b4c03c916393d6be7c5181369ebcef949eaa763.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b68e4d00295b294320b94bc777d7d34609127e0.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b7393d55600c9892558248f4131fc06a6cf3309.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b74439f42140cdda9bb0f78d995d741212a35f4.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b76e5dce9af523422782dd25d8dcf6f25edc68f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4baf664bfdf070362bcc91af77d1bc406f744351.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4bc48576f285325345fa1205e5e7e01787b74f71.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4bd4d46397a3749646b232b306688e52b8c6e584.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4be4a98f150f3f9ab6f03b5fd0968c5454565c9a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4beca56234ff6fb4f23b9b24822887fd9a3d0df9.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4bef4d120e71bfcfe61d67aa44d24ceb907c2b9e.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4c0c50a1fac82d47dff2357ee3ddbfa0b2c8d487.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4c69d06e3f32e3b6d28d3e54ad764b472741c193.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4c8720923c3452e3aebd7b9c1b4b23f0c35d7e4f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4cabdafad0bf803223ba5e8f474cd59233dc48cb.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4cb1861e31df98bdfd731efc3d335055090d83af.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4cd3de43cc1f7588d62a10362f59d113ee818846.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4ce03571f1d2779bdeaf0a6a2d617e236d191c11.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4ce671f5defd76ca08614a7a1f184c36c0f1e2ab.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4d3b1ae63e127b6e6afe39e354d4995afc5faeaf.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4d5f3cf0f78f73df79665c26b20b0805615e1b04.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4d65e58c9f147498ed04dd51fe1393770603a6d3.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4d7dc0f356b630179916f8fc2041b7f1402b46df.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4da9e9b7277bc90518ab92860bef2097ba96d982.hip create 
mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4db2e63cfebcf84043f79be0321708cd159c62b9.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4dbdd9c3f496a27bde68cf86374999ff2dd53505.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4dc87b7d385e7b092e4706c464217b004fd8a6a4.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4dde56efe17f4fd36a11cc959320a5e43f1dc232.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4e0a88ccef04e81b8c684b695f7cb4310e448915.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4e15e4f16de26068cba30ef12fc29332d45e460e.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4e47f8fa40332c6ed12d9971e0b539049a871c34.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4e760de14b71a41882ec4a2c7362565af36d1a5d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4e79dce18e49ffe024fe4cd0693ad3399f5edaee.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4e9a933b916285d9580a76df543cfafc88a536cb.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4ec2075f394acfb14fae7b1ef4304fd9b654ba0d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4ed6da5357b67cc28aee4afa9523adaf055c4e32.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4ef35d82ceb4af2e07719c16109c6d72eaedce67.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4f0aded9d1baec3125ce8e176248cb146ca580fa.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4f1e1c969b57659e7e1367ac9ba10ed5ef5b69a9.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4f44435491aa68acb3217b0e693232c67641a2db.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4f4a5d56721bb1a1332a65882132a8c5763932ec.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4f6243c6850c0a2d2b7bf1476e12f95f187257b6.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4fa4d21931b9afcbd70b1567995d3eeb6f9308aa.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4fa883a36a76edb276a66c5d779294f170d6d4b7.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4fd34faa8b168e2ac7862641229e6146d3e28aee.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4fe530cbf6363a8f08a94728e45e88ecde299e7b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4ff20bafbf156fe8fb80bdd84a5d2f3a4a944c1a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_501dcf3213efd214cc2ce8c9ba0027f991d241b4.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5052b2318dbb78b1a82ef03666a35a623f44481b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5093976cb7b32a8bd28ce92fc13af00a3e21f737.hip create mode 100644 
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_50e59bd079f4d205b613056f975fd2b4e372ab10.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_50e7b11019fc2299d70869253877319b03388244.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_50f887556a3540609649744957651ca667b91774.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_50f915b4d9bd18a3c25a85917392ea4a5e88b349.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_515128c6978449b33ce0c35b02a9e9aaad65ef7a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_522a2a9435103ed405dc1500d31652f1d431a49d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_523e5bf45ec5008aa3aba4773e68a78e122b2fe7.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_52688999141a72e61322140db29043ef9f7fbc3d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_526c89b7a04758b4badbf9695b316f877b8bb053.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_528db08068589c6e4c096054d26a2e5be63285b6.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_52a89981a05963efcea7ba5c1e967638beeebbbb.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_52a8a323414448c50571a334f29bc0a38919b61d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_532a6ffd8a21d3e98342fd401f0247f62ca4e038.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5344427df3ae9392c4fc4c25c232196828e70648.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5382a30dcf702daae19bd6705864bfe36e09502c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_53bd60bd2afee49b30a583c32a45ae9f2076db08.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5403eec1cdd216d5c4a7ba977e2ef92a0d7fcc8b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_540bd57333c6839ccf5cf2e928edb996bc60c371.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_541874a7633e5713720b9d084b6d1c6715a51a17.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_54208a6e8c5263e38f9ffcb062564ab61d2785ff.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5435b4651a90e331fcdcf224282457e3dc038a30.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_54402a22ceee3b665a3f24edb98b8398c35c6f5a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_54548ad36fb92d0963893146c8db20f53cbf0c8f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5467aea26852aa9a9e3dae76b906005ddf6fbae1.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_548b347672451e8391388a400d016803f4c4cf8d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_54940ce53998becf9bddf56df7d19894a7658168.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_549b6956eaf678f7eb901567d1a515eddbedae5f.hip create 
mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_54b6e18b10d529eb6b32d7c19c59eaefc7184376.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_54ff49018f1c12b9fa31e523ad40b9cc162ba34d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_555ba79201a585bc091ccfc326fd24e851d1eecc.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_556cd05288e1666f5c67fb87ad02ce660e4c589c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_55b14cf2998a61611d1de2594e926fcdc378999c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_55bd9c4f1b7a0621c67f3e964d946ce22fb2fc80.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_55bf8444c1c26b91fd490c7216f4d0f8aa0a1f1a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_55cda610c235987e13232e828f8d86fa88030560.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_55ea83a47c6299fefa4220ed88f7a8e1dd938215.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_566b4782793c6526bfce7362efbf6bf069928b2b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_566e26d4969bc6bbe9b092bedab11cddb3360c0f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_56964a17f902257aca9d08c736516a2c67d9a0e9.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_56cc4399c5567a9495f17d54c712cc9e65e57521.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_56de9a7dfb1201b56528740e9d8a07b62710fcaf.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_56ffe9e21362afe9c3a407c09d5de186954931a6.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5724d91c1fd6290a6cf8d52a3801ac6b921dc7d4.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_572e68bd619e118292768f0925ccf92cbfa68415.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5732094f5917e9164ee0f973ac6ec47245a69101.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5789f267d34c9961ced63ad07ffea2c6d2911415.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5854f09511778dd1779a839b0b194896070f69ad.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_58679919fcd292a2a69543de0db94e2985c9d364.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_58762476c7f2bb05dce92ec22c0acbeb03676746.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_587fc33d02b1932235b8d152e57559060211d591.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_58a784fb478ff5b3f1e2da9765a3a777efda92e3.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_58a7ab44bbd9fbc97c7805860d5f6ac81d6ae468.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_58eb2edc7738d8d18ac359691da261ceaaf71788.hip create mode 100644 
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5919133d2ed892745013b2fc5d503414cf0a4d83.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5939e6610e41aff8d1ccdb66d9e84d3e48e8d379.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_594929c433b049a8cf949ff476309a8faf5c25fb.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_597a0276ec419f18f060a5186e6bb703ae434ac8.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_59901147b7188212b8d8feea15831a11425fe4b3.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_59beb9cb4e161f9dcff79080149076488d436301.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_59d366421e0b51c90fa53c366d47ed8d51b3a329.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5a05b4e7782bd0e29ca9f6d33fc59d4304136d41.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5a216f777feec4752f5882677b18168225da4b53.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5a29b93cee012c79d4364502f1d90f947c73641d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5a85ae0a16e4b293b549bcb6a3ee52df7fccca32.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5aba1183efe205af38e79a1b2dccea5fa515d02e.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5ace1c9b00f160a17355d4583d49c47887ac33c8.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5af96b404feac271dac8f4190180754480d3ba80.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5b413bdc825ae863d53dab548f2145dc0de8fd37.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5b55946ff3c15a44b9c741e9f6bbbcb5bd4c8577.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5b7a4ea3bb8905a22ae97a94c354b1cbe38093bb.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5ba578c0e7abf1127dd0370f06d7278656c93ab9.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5bc803342862aa30e23e5be7d84e611bc571c529.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5be9ed84ad9be1627db7a66af9370679816c0897.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5bead6be6e39ece0e5d44335083336f7f546d2f8.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5c36fc744dfb0d985c9113175e76c7ec1c935054.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5c742b9ac6749f189d597ac97d46d35189472c50.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5cd03e29403ad53d6d52e5e81182ea6ff5aff2be.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5cd41b6f578f3c903eb9d58ebfab62eb296044e0.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5d707d065ae152450f9def619ddc3dddb9089e88.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5d7ed4c885fb32a0b548186e56d64bab98071d30.hip create 
mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5daedab8931f2eefb649b91e80145cb71b63360c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5de27c4081377f59363c2bf2ea8624217566d2d3.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5e0abf4e2b6be3e2c555c2134705b9dcaee617ce.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5e62968de58d9df7d687d671f37d63393f189321.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5e735b12d130ebf849ac5d6752e413ecf3e69fbf.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5e840be0741afa4d41fd4789c8300223fdc63ddc.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5ea53f7c6370845fa94aa9b395c52fd1900b62de.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5efe77ca5c394a60af0313072cdd132216a52bf3.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5f20263fd84776f155519b3481be5e2c5b035585.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5f3c3bed2b584ea2031debf9f953f5f8f7012171.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5f71e663978dbcba859c5114ec675a712e343fd6.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5f8925f929a5b26f3544ca31938aa75b3c59d34d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5f954a393b7b5a7131c13d0c4578443f468a738d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5fa19223cf296d7fd10e15e2571e63c84a80fbb1.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5fa7fafd4227918e0c7f0c6ca3b2bd673cd07279.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5fb062527121e627871b3f1b2a94b96c42e51205.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5fc66c5b53f83bf1e023e81e9d51f0285b3ae731.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6018ab272d7306689c7dc5a6d5326efea1471235.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6049c01db99fce654e9351e711b113cf7424550a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_606f5e0b99814b0a82a731de36f28024bc317801.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_60801d21c14796c08377349ec86a6c800af497b7.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6082d55544b5280b49b071ea277fb1827193fa2a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_609616f72bf16a060fa50091ac139ddc06bf9d88.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_609f68180582384ba81aae2b1d4a4c52dde2c68c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_60efa9c427dc278c0d1bc31189f683cd45e4d873.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_61204f6805d5d830aa6fca2a9b5f238ed63c3a73.hip create mode 100644 
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_61220f6dca850a5b5ccf1f619a267c40c37efeca.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_614a9f10ebc51bde3f580ef527c17f89489c12c7.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_615430cb65d8d540836c7f12b3367abd3c8e63d2.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_618031345ea71cc17e458eb97a559b7c94d3ae43.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_61896aa9e4e4d7e494c1755b1e77a08e0e264f8d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_61a44ac409e914c12281f1d26e5b52d8bfd0df75.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_61a9e92183ba87924e73ff0b5e25bd12d6038e69.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_62048a8ae1c0096f3372b0114c15edbe813425fd.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6214f820b39a8ba81e547a78ed19a909ac13221c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_621da34ee666903307d3a09b7a032f2a70054759.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_628b28f65f19e7d1b22fb3b85b7cf3d09cd54ebc.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_629e0b97b3fece7c12504f4c8f1860d611b57269.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_62ab710e4acc711430745e05e036dd6a4d6bcdca.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_62ba7a5a0f3a714eb5f9f2af20f7bfbc82a30350.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_62eb2f81e73d65fddce7ff43c397da6529317607.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_634d530731c7ade2c7beecfd1bbbca8583032217.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6360621af3f7e1e81a8be48fea8d2750fdecbbf4.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6376eb68c550b50b9aea42a7a2cc3bda186b0e40.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_63c411351ec59bdbed2590c599f9eddf7807b371.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_63f121a3c8928c10a2d86b487cd13fa995da670d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_643b3798f11997d33ccb58d90ed6c10d5411b735.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_649336d59a8b35919e593217b6fd4314a04ea359.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_64a0ca185449a49fa485892fde6af745ba758167.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_64b3488ddf3bb1a4870371882f0a5d267bdfdf73.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_64c3c1e3dac623f07c2dc1b934ccb868cafcb38c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_64cf03c0aa3f1b2a7b76b4e3418eb5063b982a29.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_64fe2db75cb20428856b02cd1cc8d7b393a6ad9c.hip create 
mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_65794d9c185b21f59274ac5d4db10a7abc0be968.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_658552954505a2092662071401e135e84956c4c0.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_65910c8b7a30acc731948ab58467fdbe4fe32f6d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_661b49505cfecbe4ec3e5c7371de3aaaa85ac9d5.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_661ffaf653085dd7f122d603bb3ba4b001e5f3c0.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_662767e588220d0dc6137b00cc1d8dcc91e97134.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6649f19deeaea20663bee781af7edced7f7a4fc0.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_66968bbf7e210911fcb95ba90c79837230ab1ce3.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_66a020f728df204ff51e37d2ddc21afb0aad5e7b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_66be70b088b20fc8de464167c35745461ddab640.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_66f651d3415562206c1049b172261fddba01ea6c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_671828f15eec2a58be23063a1a8132d337cd26de.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6767cce35ab784aa42ebcb75af7305bc38a8721a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6785dcec0197fdbb50124ab06efa627f1a2c0567.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_678a4a8210a972bb2ed89d6ac754fb79438ab2da.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_67fb736c61088b8dd92fe0371f5c98e23bf9077f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_680e81c3700f130df142c9a37a368944ca548721.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_683e8a33fdb7053760c9c135002b0a94facbe015.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_687f4aaafd1a5b9ee85aadc6fab79ad0c27a2ea2.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_688aaa193f332ed13e017e78ec07a7c80e45f6c5.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6905ba47078abd7a5b6a51eb93b26095517e7f70.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_69214eb450c3b249017480efb8d092b0edad6dc3.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6979ef43adffdb62100270a62706fb811963925a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_69cbe8eca7e3510f5caa7f13419cfbefbf031754.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6a3f42d5c9ccdd3807e488b00f02bc6ab5d8d99a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6a4b6226b355bf35d4d07aaef1828091f03ad2ec.hip create mode 100644 
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6a66604bb15f97a56847a7c968dbe32d247cbc13.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6a7b6781ffff9a42beebb4d73f0d15461ddd4479.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6a7eb3d86aa385f9ecffbc5ba10489e56856f918.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6a95543aeed81adfb6d847f78212585a36122ae3.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6abeb7b50ae6a1fc62535b9a1dabbde6f177a9d0.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6af23d1460abfe875e71f7911697c42fef0f41c5.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6af4c15a119e805e4407b184625f57966f8833d9.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6b0ef67ce0f178aa2863c4909f5bdd7f766c9b2f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6b638314efcc4f16aa4a6e58e6caf2fda1711519.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6bad2ed9f91bc1efd89ea66cd5c775fa140cf931.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6cfb7075345704340ff33dc0ef7c04ef127f26ad.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6d07bf9c05e41dcf2416e05dab4bdde17158db76.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6d17b92fab5bee7717bf9aff6a6bef7cee3816e7.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6d307974bdeeef95cca0d130ebb7aeb77fb1b6eb.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6d40d762ed576832b3a752453e9881b5fe6d2650.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6d470f5c6fb81032fcd7974180297d4bb2a8427d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6d5aad18f59e47a3fa3278c7ef1a6372830c33d5.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6db86621d626722434f2ae9b7b8ab435a8dd8827.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6dd707cf48a17d31abef94215c5720419faa0a39.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6e240106c771ebea461fc2a87b6da68e510aba70.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6e6a4475ea795935f4cbf2dc0ac156a33d754587.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6e7e1d245baabe2f6293e3d85318f9936b333500.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6e8cda718e10824956f0ee39bbb0891eafa45a7b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6eca9cd905ea8b0454cf9564643894682b08cb97.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6eebd0c2fbfc85f938b10535855c388971129a28.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6ef5803b33d97db72eb8a8528aeb3fc956a938cc.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6f31b3345893eec8ed1ddf1d8de2512b46ff6187.hip create 
mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6f3d098f8bb63133924aab70d26a6ed64018c13b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6f8788c537cbf6833c58a6ca15c0a36de33c9fbd.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6f88527a2cdb5adf51407f4661a254bb32d7de23.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6fa6478cc27e52fd9511fbff38369c921155cfb9.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6ff4605d82507fc4bd6e96095eaee5173ea41973.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6ff58a5186d69efd6062f3717bd315394ea6592b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_703246f1f53a988cf252eff88bdf814bd382d3ac.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_70586668a61ab88bc46b763df8f1c2ea52001ea0.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_70c8e45f6ea7cf5dba9eeadd0b19481d9f5defb7.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_70cf755f1485c065222be4daab84283a9c3d0eb7.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_714c5369aa848021e020d874289e3ae4e0f74d77.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7177f939ac3dae8749cbf4232dcf04d2cf63b48f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_71a2d046629a4b65c90d0e18d061c4984062f844.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_71b6100efe30d836dab557ea4ac54c4b9d35c6aa.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_71dcbe9f481c92215f3b636bc0e86ce8f65e6472.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_71e3980331dc4bcec6ab6f4c345c7b5f71356979.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_71e5fb3544dafa9da03fd2de4bb9bd0718f6009f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7237ce5f3cf13ace3efc0b0227ae5a8c1fdfce1d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_724d1d4408196d611b2e0535bf8833652acbd6ef.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7264e378e1ea1d4dd97f6949d66f3492883b663e.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_72abb25dba0c48b380b2dabeb6ab7efaa706d180.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7309c38fc8a2d5ad6efd449107dc54a7509624fe.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7344f96bed2f56793b1c2583485aa161cdf30379.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7393267865f1c2b0aa1a09a586f54cec98eea4ae.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_73d4901b8ef034590314048de7223a572d61ee0f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_73ec21ed6e040260c4f04ef68ef9307aa86985a7.hip create mode 100644 
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_741401abfbbbdf0dd1d62df8bc3e85371ead71d6.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_743176ecb1f0bc800c870861585edf56f88d7739.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_744ec604c577a27e0aae5b39711a9e2eb82801b6.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_745705ae121a1a331527cedfe4d31218a428a0df.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_748a3d76e8ab73af9a5d2302d33e3b1d1b866dd1.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7497eca4d1a18306b406b367653622a8d64095bf.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_74ba59d347ce8916a22b40e6f22a3c89e13db4d0.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_74d5f2aef029f2103bb419cc982cae99fd1a9253.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7524904ac5a2040c7ea72aef5942212f291a21bf.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_758b211174da0f398b2a093e7389905b4f9c4060.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7596c14b8fee751d03f42ca48ea4f66e87fc2e2f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7597ce4d2e5264bdeda47487d5bdb55a014c6616.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_75a310a6eb86e3e8baac7a930c3ffbef372942b3.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_75c38912947881caa14b3fc7ab7bca317e296dc3.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_75f2010bf6c478d2f0eba77e912697661306c1cb.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_75f21e38ad01fade35b1db40adabd75eb602410c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7601e6aea44b96e94fb019501be6b102c6e6a654.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_761bde840c0c8149b24a8f6f264e963c4e9e8ceb.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_765940baaaa2ae6ade43ef4c94a220eaa63702b0.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_76674fc182dfa6329c73a354aa3adf458429444a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_76704ca28a4877a1e84022e022614709adabb280.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_768c80fd3ea17813df1bf19a158186834fd00780.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_76be322fc072ca19baa82707e260c6eba936ae19.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_76f884e9ca116ee47b446efe9fc770c178a858d5.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_770ad1eb1b30ad8f1e7c17df486093129b2d5630.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_77200e875e0ef160b311c7de450c137772312d0d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_772016803aa3ca6ebe785557118365f9be7c4339.hip create 
mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7726be8909f631c04d4395fa4ffd03a736f447f1.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7728d5bec7941c9b6d5632bee8d67ed92b9c03ec.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7764814a0de7702f0b7b5ce9dede6440603f4853.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_77a814291d8f01870274149b9d82fb75921d6e20.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_77d0223697ed41c4c2fd8830f8df6e5620db547f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7831ce329f2a0812ebb1dd103ea4ba8cb7ba531d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7838849e57ee9cd292e588f587a8079b57becfc8.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_783ec08544591a22f59dc12f169b7327b4185a1a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_784c35fee4d372123631312f1051c43e1fa12378.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_78663faeb0425f45e8a0da0f7b1a5ddbee5e07e7.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7872c45ba170f2782c4b5b75cfc78ac79a4cf157.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7878e2a4d3b96a552e03d1ffc33debfd50c9f7f1.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_78e1edca5abe1bb3e7aa946eab6484b7bed806a3.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_78e945db4afa1330fe3978bc1bc9ae99828ae287.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_78f7e2a2c08cd87702793f91b6935cbe4c22be55.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_797750ac0b18b48f56ceb4640256e9bd3a36621a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7993fc08ac5c6ce7a2eceb1227f4e3718dc4cf5f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_79a7dce707954e765d97cb22e57d9bd6168860d9.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_79d0b8053ddf99a4d4447656d733c2da026b3a7c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_79f182ae021e23869d7bebf2a9b4575bdc910ed0.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7a0ab620e6d62259a559e329460e46e6e3f7c3f9.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7a13d62a715fd717f0d4101f787349cb49cbe70f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7a242e5953f44316b6a4f6587ec26283ed6cbcae.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7a2e032f6500fbc5468183415b6dd1d3e43f0bee.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7a890b126da2d8cfbf84f048b779cac2dd56b509.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7a902ed4ae3cc6558c73b730ff3949778007a230.hip create mode 100644 
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7aa14aa94d625b33df1adfa30ef4d91769592608.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7ab03a62e064864e1e9c1cd506c1b2e1786a777c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7adf69b51f0a8cc9ae7e250e60df38758230fe4f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7afd1a756247b15b078d15a39e350a07c22982da.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7b2d3680c3578c7292349b58843aef7a82e0087d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7b5680f97836be4a369802e8115617a83875703e.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7b67045d438a7e4b8f3a313a5df5a85f351c1be5.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7b7fa76609243a8709f349ffc0d9d88157f28dc9.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7b9a3bf1a9b37e0bd9bae6249609e5994dc0dba1.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7bb7b63e8a4c1df4eac4d978e166867195bd6e53.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7c19fc90e5a9c422dbf529d2def286f47dea0f50.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7c23dde1a386436e9864c8fa5f1706c0d2fbfd0d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7c3d8ef4da515960bf40eb1feb04d21950ad5ae5.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7c4710e8f4e27fae4ae079f1667c3a1879cb6da8.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7cbe4562c51d6829ec5942e11035c452fe318b3a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7cdc419d4248dfdeeab1f0980aec35fa134e52e0.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7d08373ace7087bdaca4ce8b0bc329f553f88d77.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7d0f767c17385eb7d756cbe8ed444d7cef72dea5.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7d12e9cb599d24631c082e3cf65d2c58b6d4d44f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7d2f87c021e0b6a27b2d7e30351fd50f06414b5f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7d5667b27f15a06d4040354fba3601d48bb9c045.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7dac5d4cf103d658e129673549549f1276f134e0.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7dd260849b86c46b685955cab54ba07d49b47954.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7ddd621da88c57798db1e689b93b692b6519ff96.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7dfe21ee27f8a0ca0407ef0dea73cd73ae6940db.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7e1bdde812c332c9fc58613698568a04771b9fa8.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7e332a6aeecfb12dcf70c69157fd3137343fb9f6.hip create 
mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7e6129eead18d13a4a6cb9550384fddabc7a2a16.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7e89f79217037e361bb0909d06534e40f5026b4f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7e9519dd0d0f940fd5efd61bd32df7528ba7e3fc.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7e9c7feb747241c9c7de2adf3a19933a1c4c0995.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7ea9c37d92e344f3cc58cd4d1d00f19167e3623e.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7ec038393ec329a894aee9bbac078a40f57a4684.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7ec04763d635c5bc3e810737b5d948c59f117d5a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7ee953cb24e28bcdc8f05783894b23cbf83bdf35.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7f6ccdb3c2d595fffd05bc5e6417b157276547fb.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7f80d44e82e601dc48d4c8b4e710ef7265894b6c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7f9403cb91d6aabebf081afae94a8ba397d8d24f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7f9bb3486fee7b7c9e24300b8a4e4ce88a11bfc0.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7fa76fc1b066a15b08dc6c24a7cf33a58b4cb6cb.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7fe409f4421193fb48a54aa5f26bd6229d23204c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7ff65c7abd9b0d8a2df9302d6dc167637b3a72f0.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8004763f674dfb3f14b66dfdeb2a046e413ce2cb.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8007bf7ae1b71bf8ac4a793aa519ad333aa7a7ba.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8021fa266c77e6b5bd1af2a9c22c686e5a6eac78.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_802b21f9588d72c3c3e3b9a3b269f19c484d5aa4.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8046f566fa7188c92568b277354e8b06ad382544.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_806f9ab9baf631df1d3a8d801e4cf93a102526cf.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_807545400aa6e70ff49a5f38ed6a218a180bd87f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_80987e2d765efc320eaee813607c94c80ee35aa4.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_80a72d70d80b66c19e85daa00497308381050048.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_80bfb0e6032892cc58cef4dd403f305a5b76851b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_80cf0997573f4bcfbaaf75e40f519580a7495a17.hip create mode 100644 
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_80efc341089a50ed5669b3c86f6ddd9b124d1442.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_80f51f0e178c33e6196df1d2e47bd38bf5391cc8.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_80fb694fce7b4c3c459fca43c89c6002fbfdaef5.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_810dd4e870ceda3ba9b5f0084a4b025b2e609d57.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_811db756577b61cde9fe8279d956980db9ee21a4.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_813e60e8405aca3f7fbed19452ae37574ada9a77.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_815918206483d2ae04a45aa67d69dfb986587214.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_816c48e129a0235cb3a19124ddb28cce286fb368.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_81acf1d17650712b71a499bb66909bfcfcb6aecb.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_81bb8f13b6f20a72c9ce6d0b53f81eddbf05f1c6.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_81dd3ea61bb61de02667b14f5a94198f48c7307b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_81f6c575c3fa2ccc7e65022f1ba65c8cfc16541e.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_82048cf91270631f98ac37dc488a1fb2e00ce004.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8250f27341241086515d833aa53ae873d4ece3fa.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8278845045d68027dcf3bf867ecde2fb12ec51d3.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_82ad0c0580516485ea432d98f53e73f6dfec548c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_82c932e6eaaf44861c794539d9caf8b50192fc44.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_82d7f61e6313930f063758b61102e7a43b118beb.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_82f0f3d71108dcc49234a258f0f3b21ea2123cc0.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_82f1d7e1a93bf2fa80c409e6827ea88af56c44f0.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8301bfc0394936a68fa0098580f06e77c88ebed9.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_83080406598df6bd3102db70a554e496e29db96a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_830e3532f27b391585d5de90f3bdf97992b67651.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8352031044ef2e4a22e27ad04ab5d2c02121faee.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_835a906031a258c6362313eec783678bd8125c91.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_836a308c2d2afd6e0dfbfda61984b631c4ccffc6.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_83d580a612af85533c87aecdd7b0345c71b75980.hip create 
mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_83d920a76114c63156740ba5dd6f3846c4b21c28.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_83ddca2c6ecbba4314c434e7471ffb8fa642f936.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_83f6a1837a65df12b7c55d25ca28cc939c2a6328.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_843e7888cba5f463d19fcb71aaaab25dc3d2c09d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8441910c34830ad2459fb85c2c14af02da718fdc.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8457ea5726149efb8778e6d90798b8e48288fc9a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_847feaf237911478173377a501ee19ee325b012b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_84cca7528c7d1bf49ba79625733ff0ae7522c096.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_84dc4af43de08130a04bfa06df9799b6e9e96900.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_84e8ae99e184013739019c93d07caddce532382b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_84fc5e94f89d6a9287cf64662a372784511468dd.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8513d96a66a4d9fb8dfc84afba7e1d8c200248a6.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_85156f2c556c6ef6180608c361b7b35ede71ffea.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_854c8003a508ed3f8cbe6967c4ae2635a491c721.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_85908fe6dc9c629c82d6953081b10021e64583b1.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_85960fe542635079de5eca3c7785890cd4740005.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_85fdde4b25e2fc8cbdd46c2850c19eac8d9af8f6.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_86309c036d96367939ccc3e8922595ac35a3e179.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_86513d6e065a44bcb0c789eed1e7e5456e800ab6.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_865eb90b1a2d64acc0f6fbe1d807c501fd4be3cd.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8689126a7eb09d81baaf8f99dbff8932fbeab3cb.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_86d73393d0d8b769f30222f7817563a955c36dfc.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_86fa51b8c7a2f3fac5cf4cd2951ed2ede5c35450.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_875b08ca602fe48840c72cd61798acb98540fcd6.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_876a418fbe6183d0392b7a7d9986d067e323e2b9.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_877e33463b3bf1853c6d2d2009af8d27bf88abbe.hip create mode 100644 
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8793dc3217e154b65ebba065aa10ab4dc2374ae8.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_87e3a06266deda093bdf28af82d8666066157fc6.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8840e8899b4e632714632450bcef001c6070f955.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_88ac7f6cbdfca2e397bcb86af4216e87166601c7.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_88c04463f9c5ce565a9daa8c22e16de80fadd707.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_88d52c5f70abb525b9c8aa8fc1cb3997c33ed67c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_88ea5b5346c87cc4fc1e841c518080df4ab811a2.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_88ed7f650c958a644c8031aeb88688b1e42458e5.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_890aa875ac13957f00b30210477924697abf0c9e.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_89617bdea526d12d6a33ed42b9b0018c0b173722.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_89a3327da9a3411ff1cddc67eb647083cd947a92.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8a1fd28acfe85b3adac859c4bbffa4d28fe634fe.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8a58d4bca33c4c0e79141a56688049237d170d1b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8a824621a50cdc3cbadc4b1f9ef18e1325385082.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8a980749c6b2a18c80426dd189e5506334343ca4.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8adbdcd28cb2f078f89adf9aad2b3d4a0a477823.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8b17c082f249649eca733a8f0cdf9a1205c3e3d7.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8b9043572cabb65435627a3faf23b18d039bbcd8.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8b92990df507e82f96eeb7aa3ec00c01437566fb.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8bd1a40b12ce927323594fcce61eb9c20cc5e3d4.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8bd7b8c63a51c8639b3cf27ad09d41ae47c480d3.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8c074afcf33e3f3534ac3577484237fcfd2ca48e.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8c13c4f3f645a2bb475eb1c55ce1de452f0e2332.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8c3bd4e029bba76ebfc79e6522dbc8ca0bba5dd2.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8c4688cbd23727dd0ea9a36fb977b31aeae98d65.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8c7970957024de050748d3e31cef434f582d968b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8cdcdeb845e7bcdb89ef70ab2a97157d4db3cb52.hip create 
mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8cf1007430da272174d3476d042f398627e83512.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8d079c1eb36db8461fa8b861c56760afcd97cc34.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8d7549e66ef309e32779ddc2a1f14e79bae53754.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8d79fe8a600c3b4e0ec9aa510f8036ba2b608985.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8da8285bd6182355e3164cdc5a983375cdf0a61d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e1b48a28b71c7f4c78eb14321b39951a7c5e903.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e2c587db8bd9f1b551624e0cf8b67a90245d7da.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e2d5f979fc4fbd0991581a020a414f9c8656ae2.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e431313fe082958d31b68d2fd0d61df0fe56736.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e50ea8dd480012cbe10be392cd26d1870e6ef9b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e675919a6c7758cbbeecb83b7ac6c62f95cdb46.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e812705ae3e452810794fa7caceef2ef6066dfb.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e816fcad5e9ecfca94a6491eb2274bcc41e558b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e938d0e3ad30db201880642e57758285b2ec4cb.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8efb5fc2ace6839eac741c5e6616665845f43566.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8f607ee20c0d92b6dbd0338f139517fdcce98d0c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8f6e463eedd3e65b9c79feed3cd92ad8cbc9f036.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8f7166d4bb0c1c9b9999ba16a1adbf09ebfdb6f1.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8fa4c40e244b412a07933d369704bcdaa6d5e74c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8fb224b40a7be7db0a9c5c08cc5ab05b526c14e8.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8fb33fc20f2e85e915f1b1529ae87981dfcaf86d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8fc08b4f3959a2375ac03f40c4ce12d70cdc2d80.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9009b7d39346537aa6c4a4e46b81139f603edb60.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_900d7f81c73b35ea64095d01c5d48d9190839e0a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9068ba8df8b0e977e9769f6acf6cfee6b00b9922.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_906fa8bf5e992ddc25815486ae9c24d8bfba7227.hip create mode 100644 
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_90b17d8cba28cceddb3ef907df878aeef0762d15.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_90da0d469cca5c8481504148468460c85a15c559.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_90e5c56e92712d00092ba102a5eb5176a3e5d471.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_910cb8bd09d287a1566265eb1e8894fe68d3cc81.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_915b75db795dbef037b14b003ee073665fe35d3e.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9163ae070075f26926a86d39e15c27e6edb1f1cf.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_91695dea4171747fb3cc6d910459f800608d07c1.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_919ae177b7a793fa352c4f6bb8e4175f3064d814.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_91a6200e36944b1f11106c02f7fcee053f01ee71.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_91b9e2616c2fe0480096b1ccf0f74d584b220146.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_91c916e14198f6d18dc89915e379b01070434e91.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9207a63fc55c411c73e4f93306c5ffed800dd249.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92121fd448b4640a17e1a7fe73bb7b58714c0afb.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_921f789d619db6f225e8e9d646e93bbc9dc1a669.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92739f4464512feee083b875e11e11eee4f5b448.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92992be6252f2afdc368bd4baec4b8a55ae0abf8.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92b0770fe64e3c60b9e56170aa88bbf74802a813.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92b722cdabcfaa388ccc6ccceb7e42462f3bdcd1.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92ba64cdf615c1be2865f027a293cb530fc07dc6.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92d841e6d783bb46d841aafd9027f92dd1b61b88.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92e53359c69bbe4d7405d45261a8a62008eb7d06.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92f9ad0fb65638cfffb3e7786f2cbf01d9585b23.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_93054acb8a9508fd0f0f486367fb62454de47c39.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_931cf8d05cfa45319f4e5bb49334d35a530bffcf.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_93728d999ae43ee1b5a16e60b90cf8533c7d303f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_937801fbb43fb6797f0425f08d13926b74d87c4a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_937c48d0b7096ad6c8bc445f13f2c8c1934695ab.hip create 
mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_93b885d6869400b0dc2ef1b2c2636ddfd21cde31.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_942439e4f5644a3a4630481bc7d98834b29b6e1c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_94a94d145e575747c8956ac703810582c819e2e8.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_94aa519eb57e5797125728492d9330f5c0f0670a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_94f6f9dee9f0c3825d91f4d320a5280070e60ee7.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_95061acc6650fc7b79fa1fe5b2b1e083555eec2c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_951343832a5bfd060c8d12da0d8a090f070a717d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9545f95c1093c60f0fb6c794636f79aaeb53b733.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_95530399ad7b43d8ce2c89da24c71056f2146b18.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9583148fd684a7e6a312127e023798278415bd27.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9594816877815bc0294610ca24f986fdccdc7c6f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_960ecb3013071fb65f2d5ed4c947c4bf303e5308.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9638c9618dbf2af119e37596f7eb0fd3f8d72748.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_963986150adcd6e1d3886bacf2166de1252e14df.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_964f916d3484295b5918e2e4c22c5529588a5662.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9689ecd7bf51bcffe9f5002959bdda41c50a3c8b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_968fc75a7d102aca068e3ceb6111728c280fa837.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_96c129dd4c798343d6f78ab78056f0faf2f1c9d3.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_96c5e79f54b71677124f555b0ae4bfd27248d099.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_96caa2056d99eb67ada498e287b4fae984397691.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_96dee49ec6755006d67f0c30c65f50558bba69b0.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_96f1bb85dff8c97846f6b2e8796a6289bcd0d9d3.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_970073c70133ff2ee4737f803a0ac43801c47242.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_971a08c2e48d805b295d979b24173a04cf58def0.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_97246460c21bc66c0f13936d27477a9fca1c44d1.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9745b04a8026a01828c5dd606d89d044d3ed1d99.hip create mode 100644 
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_976cf509d9c2bf86ba6ee5ded544fa8e6717f590.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_977137b371df841993c8d0584be7d83aca6add78.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_97851d5ecbf02f8af623988b1a39c0b91e51533a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9801b25e0f132d647934deb395b62a3f70cc7c88.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_987a617fae00fa90a1ba60937b0312c81087c19e.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_987f00dd759d9714693e7517dfaa8bb427294d42.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9893336a4b00b2a63f23ed7e13ec54c82d9e5063.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_98e484adeddf3394d8d7693b808d83b64c71ee69.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_98f5efcd500ce6b9ffc14bc9877e0ba457539925.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_98f9a4f4d85f292b78123599a2e1798f12aa545b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9990e6ad243a48b84304b5cad0c663c0802aedfd.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_99ae680eed89ea93a3a94586bd5a68dbc5439f37.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_99e2f290b962f1617b0a9d4fd6d55c43e4439d6f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_99f8352674bd6bbe98944a1c0a769a4fc028a623.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9a0a70932bd587759df1e5e150b25b0126d7b529.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9a20fa19d8d30654602e363806f559113218d66d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9a8e04fe9432a60f86ff0369e8c1851821074a04.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9a9edbe35a8fac7796f00bde836bd547044770ea.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9ab73ea77ec20ea3bfaf995dacf93a6960ecdca0.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9ad1f99284aafc8d7908d062f179a056eb314925.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9ae866c7db36286876818bfb718ac35204fa3843.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9afe4b6f3b901ff4af81bd4f1cd8ff19f09d0b07.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9b062dd633645772e4f2caffd111af73184f7657.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9b327f0fa1155f2235d76be45cd22e3db5a69429.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9b4dcde1ae3446b825dea739d4295c1d1ec5c4be.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9b6d08e63b9a90f2524cbfa8c5fcf8b82a1d2d36.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9b73c92a13757877f34bd8a13c6fb29b60999020.hip create 
mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9b841b7cf5da31f0c30ec42c91cc8d5bd3fedd03.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9bcc791049e3ff9ebc1a9085d2d20efcc2f99b71.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9bf235679af1ca03a6e601b4cf6cd0416d1c9091.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9c4fc7cda4b560040cec93f63021b529aa1ee3fd.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9ca3b1d36d777213eb381b47871bf15dd163c994.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9cc3ef3d3b36f52089548e9dce522b0448e2c26a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9d3d274058bc0a3d4d35d90669587761fdfbdba1.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9d6759d8855c4c6289f1f241a1628cf0406c1b64.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9d69d441f48f9ea346dd8e00376a9a708da3ad87.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9dc424f0e192155e3c4e786e5b87d5a1a3e6c4ad.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9e51083e13aa4dfa8c969f8f916835a8e5e9ca39.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9eef1b54d5d3841f3fa6b84cca6c7ad33efa2d9f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9f0517550c7a23882b95de451e8099ea2186b4ce.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9fb389d4b5ba590baa951f17da06f0e53d2bfa55.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a017be7b8bcf303b30a147f41346898acc5fab7d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a02a71fdd587e47ee68e0cc76c3c4494ce06c359.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a02f152e9184af0b3d77082d8bdf519dbbfceb2d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a046e888e3836b0bd3c49fec8e1872e880798f0c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a0874fc5ac87a1ec487c7722bf3b1bdaa924ee09.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a094599fb5caf5e7aba728cd4713a8d0c6368a46.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a0a556c9358ddd6db719458c81d2d6d822a895da.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a103cd47156a98ad2cf2c325ea00df3f1d67fb72.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a189292c81a18d21a2921ce6740f81ebf4c046ad.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a1c71e7d33f0597fe090a3524e33e18b2e562680.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a1cba1509c413c870c5d784410855ee1bd737da2.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a1d6ad9de7ac7993ae1923a2ef070b7dacb8c563.hip create mode 100644 
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a20c91b2f11bb7e5058ca7935b0bda4f5558a9dc.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a21f3637624762547af1292e1b85e640b1d329dc.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a225c4f1f3c7b271957768bb9235131c67afb48a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a2482a64659c838f3da55f56e3cbbee1dbfe6722.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a25e2aed617e1ff31f93ae7e054313ee0dceee97.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a2a715b7e9c1a576f011dfe5769c5b392e984f82.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a2ef5d30a2318ae06430d17f84878800c4ca7364.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a3339150d8bf9d073827738527f6cbe15b854607.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a3709e4fc53d2254a03ea7660b8c72d2f47cf1ad.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a388a284f45f711d82a6ed87036d87cef1872eb1.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a3ac4f93722dc314086f1b7d7b8adc687cd75f82.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a3d7aa46528ee74e2bef1e87c1feceacfa55e173.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a3dc780b17152f696f9b957432c2eae8fb16e85e.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a3f9c236d24b30bc9c3fad90cfd6eb00da835de2.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a3ff8445ba691807caadd9f26e7eb90851875280.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a421c2ed6b295c458071f1988b9d6f7b46e8992c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a4700d87a19a173e84d64e43cffabbed52366e35.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a487f617c4b84c6a0328fedac750d41dc3dafe27.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a48843d844f78690c7a45b730652f0f763c595c7.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a4980becb0d3149fee575bad1fc3b463d08aabf5.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a4b7f10440331a8a88ff93ba253217c2832bcf9e.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a55b47aafc4340e69e300ac61a7601a5c14513b7.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a55c7dd576e5b1061c059e5e99aeedf4389e2d25.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a59423c095db052603d77073d409534bceef425f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5a7833f4597bb03a3e845d5580d677e97421040.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5bdc110955c05c6c6ea236a6f60266a4a6dce5e.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5c0109313de1f6245d2a80f8539485b849e9d55.hip create 
mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5c4dc0d70c547dbbfb661e879ba7f9adfafc2ea.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5d4eb673bafd81e3a0ee213da4603d88b8460ec.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5e5cae764142683b70d3344cf07dd1edb7d69e2.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5f2f0cef657ae5e333d65ae4ab20529a43cd7de.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5f8b7b2a891aa9f2ab49762eb31d835efdf18b6.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5fa94bb32a80e81886b711ebfcf2df5f5405866.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a622fa57764ec746e02f6d4bd4846b48c722b807.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a62a2ab489839ea1a1bfd1b24e54a3c232ed934f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a6461d72fb6ba50e81de3f661528c96dcfdc3f3c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a64b4cf3f6706e4b4e0af4402e2263b9a1585f9b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a65c43b870705c780d734f9ef063f55cf8b3b52d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a673f35edd69241c6b921d6712dfd064d78ecbad.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a71305f191f06cd53b7563971c706e8b71b19e2f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a74b0e7dd816ad08eec5a1bba6e227afee9813ec.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a7784b03ad757d51c234fa86ea9891f055ecd5c1.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a78fecb9725ceb4bcf2aa037d43bc43efeb1c3fd.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a7f7553a7d2f6d42fe695cdc64423c85223af440.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a821661d8280c6e9d27f2c9ce1b3c855387b5a76.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a85d35b2fd98742427930eb536e346ffb005edd8.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a8a4af070ee46d802cb11086b93daf91538f8a04.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a8a744edfa3a19d1493611df5bd0d4d59b707d43.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a92b43d374642df991edef1f6036dc898bf77cf8.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a93324ccf11b273ed20fd960c61df897c8890b1d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a93a03b33305b33055273711ab31a5b8d8298d5d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a968df29f5ae1463706b7981b3bde55918e1aa65.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a98925d99dc484da41dd55700e151cf545cf821d.hip create mode 100644 
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a9b50c6ebb27986ce5b378d8c39315eb9cb91dea.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a9d2be18e2d53a5144f97dfdebb225fcb6d611d3.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a9df9ac4ee78e5f4d5bd0567e58a7090907c61e1.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a9f00f270680de81df7737e848e0408cb070e68b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aa1041530f794c7b8dc4a8321ea0fcdd338fff35.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aa522b43c5e5ea69bcabb4c0fe28def2bd081a12.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aa6d13b09f85ee62bb5018608812181fb43afc86.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aa82d20635e592edbf00439294835f6f39ad54a3.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aa996b9c843200a2ec33ed4319b48106cd7c6384.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aafe891dad43815e635f81225705ff944f990d75.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab09941bddfa9d61985b55f9b6bf0edec9bb89f6.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab0be5a2072b5e87f5ee58149688796b6513219f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab0c3fe9529e24327686070731d0ac3ada76245e.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab1ca4ce061f7f69a250356f613cab00d1e2ac71.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab1d7f93427095e39bfc1d986b3d7fe54073ec75.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab43f4a56c166dad0113f51b337a083f4df7cdb6.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab56e886d53a1d88fada0f10f00b9f398dc54568.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab6cd5c9242f8278c8f3d9ce57b97d605c7e5a3e.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab877ae2a1aab04498bf2b26b3fe99d6488ef151.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_abf6c6412f9853855b74a96e862935ddef66f763.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_abf92a5314fd33491b5eb6ebd2418b7e0d5db774.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ac1ccde31b47e0e56ee0daab6403fed7895208c7.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ac5e9aee85cd16903bf7b82a4ac10402b0b26e22.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ac9382cf8bb56ffd962c99329bf67da992f8810d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aceb0641213e9a45ba48bcf72bb23845720d8b79.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ad091c69d19b27f7ad50ef6311532ad8b642a9c6.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ad82071cc074fd30437f6158b5eb2c6df1f8c587.hip create 
mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ad989d2ce769f20e175fa88f4082c1c25fe03062.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ad9b99a194b59d3149842c15733394da275b12c0.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ada016be2bd0e377fbe01fa7adb9bbb8febce100.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_adae2d4f8b2dac799e03ea6f279e6ecdf66f5381.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_adaef10ff2c5d89530310bdf1d53a194f06a94ef.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_add29e3e9828911a117dccaa5650e77805730d14.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_adda7ad787524e3e47dcc1b65c41b2faea38f55f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_addb6a14043c5a4df0f5042b3770b40c4e90795c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_adf160741a4f751d2f15d6eb23d4121cdca62b55.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae1ab1f4bbe86bb9bbc22e4774648076c321136f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae1afeb6cfdf860ff08e4c2f11c922fd5bfa621a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae239476d61f48379754b97f29d7a285cc3192de.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae4e7253ad4873576052ec0a9400597bb7975753.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae4e80cb185759dd9b3eb3c67c239964b3694caa.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae51b30c7e1cd30e550187458350c8db7c59a9ef.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae7899b1ef159ecbf01f27014601eb79b31b49b3.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae87b1d5c50606430b544ed650d87df24366e7d5.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae8d0bdde763e617beafc0365ec4a3cd11df6c55.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aebb2441e6cc1ccba4a391566e547402bcf7ced2.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aebd5fed34ebceb879ae3dffaf58c7c04ab5fe80.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aebff7e6605b273bad844b8f70ef031625bff48e.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aec87e65afa93e84d7a947c52f291c1c7360033c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aece14f7a220222eb4ce6783ec2b9fce6fde94b8.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_af06c0dae15684f83e15722a4c07342af9ea011c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_af6ccfa11add1ae49888337e84d9c446d2f67da4.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_afadc4f76e237514db0bc0203102297b79730bd0.hip create mode 100644 
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_afc4b47a6fa62a4ca5cff6a7e01c9f6b371d2215.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_afcafd07c1f56e74373ccf37db35976023456d50.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_afccf699f593c828e11efc053b144044e45b32d6.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_afda8f46b5ded4c2aa9d722fec17b75004b59f7d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_afdab954fd111ec48721f25710d61c0c8affd8db.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b00e062055933388e37525df5766f3c14cd3538a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b01dc872c24db4db0c9179fc07e17f41060390de.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b03ab68e33844f97aa58d463e00037bc11c50da0.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b04f14f829eff73afaa57a875f74ebd1e6860979.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b0544a38dfdf4d81dc95894387845f48435e299a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b0dd965d5d9080ed5c6a04b7eea9890f3a264f20.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b0f555b74ed36f1bef8f47880b3edc6760f27788.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b1766695dbb790bd614b83dc7569ad449404cc89.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b18a615e66d7cd739ce35412811359a03cb23a8e.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b192c55f002d8540d5f965cc4df0c2e33f4b9ff9.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b19f05f6848403480ba41d37cdbf44ccca1b1f8d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b1ad101ce91348266d3885afdf2996a0fdb72135.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b1c5d55d47d6038e9162d32ac968ff58c0942938.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b20c6252863a73341b0010191fad4c834860f884.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b20e314642cf565e4f32bceffdb5c0e653ab627b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b24f91dec2029b25d0d96962528410df55a468ed.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b285e2f1970b78e18002464eeda63798229bbc3a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b298e213f927b518c693660110f08bdd94990ef0.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b2af5f5b5ee3ae964824a3e9c7bbeb5bb39c557c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b2f91e937b427ecc932c0cb0c90b2c2378db0be6.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b3063d06723ac70c5f8802ab49c5c35e1debf56e.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b31f56244076c501cb09b4b90975132cae4c4386.hip create 
mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b3486244e0b7d6dbcaa1951e8b8883ce441c3f99.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b34c1ce348c3d9cdf6bbec9758de9d5fe94c43fc.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b38a1d3cffae01332a3a9d9472ff1b2c443e82af.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b3a104733f678193068d8642d6560faa03897258.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b3da22d3482738a8474ae15e8e5fca9020c4e195.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b41735d250b5a16967281a5f07873b9cde3df4d6.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b41a30092e8138877c1f6c25656e0f8ae2c2444e.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b41ea5293bc1c56efa2c4b5681d965aa6f2ce6c3.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b4588379eaa268d79fe8f8e4457b009f204a5fb7.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b493c99888d82cd2852bfb101f99a2e6a27665b8.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b4a5715b550f67b8870ba66e1e6282a26cc1dbf3.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b4b037a2e262d11d3ed7d9feeb41b9e05427a739.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b4bd2d206ceb237ed2c51f58abb5cbf96e39d07b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b4ec377c44ac18527ca6a01bc3b146706a6e1e09.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b4f12f10d7b968e0d8e7c23f36d3a360de74a905.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b50e6df20a2426abd3d2ff2262a37c009196024c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b513834918d5ea789e2db21abece7c2d3532a7e7.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b5248f443a12d96815c04409a00102923c717023.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b5371415448fffffd58bf014dac9f4876153657b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b5ac596c636df55e81293228cbc53dcbb3024e5a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b5ba2e73df35f6e0f7317303823fde92a42b1a35.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b5bccc85f74f54a2ceb17fe3040b04fe306c53f9.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b5c3131fb8e5a25bd4a14bc9075eb6fa01b61d02.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b5c7fca1f76a31b0390e92d90d569fab94d4f783.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b5db3d5b1d8af89381fc4b8073f84c5fa25fdef5.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b60a4e87a7aabfe3c1ce02b408522f3ec862e3d7.hip create mode 100644 
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b6b17ae67adee9e56a022cd2a5514fb9c4e99920.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b72a804bb3c99830653d41ac0bd49943c801b89a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b737410b404a51043fc3bd503c0b107c297e4c9f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b75843bb13058ffe29251e053800c509c7590544.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b774450ebadaacf23e944aaf8ca90eada01e8a5a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b779cc0b0380e1e6a2b51fc6216fdd72215b882b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b7a03ab0b7887cc7ed0cb40e56360a8d36c0bb8e.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b80d0828ba6d24ea3c1a97bd9835ee937b4b32fb.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b872f9e6ebe330cc1818ea82b53acec79a2f672c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b8fbc6f6e9c515edce3c7a438b3bc308b30d3857.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b9385db12001110c42eff6aabad935a69ad3afe2.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b9559dd36a0a4f5e068a722e285f485137bd5ef0.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b9627f9c8d0088df0364a64643f2b5dcd951f2bb.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b9a742ceeb6736a2c8f9439d0b05e10d3e0c5c6f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b9baf70220079e6d4e87eb01a7259923d8a01e29.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b9d00ab8373747a5c6b9d2f8dd50ceb14db4163c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b9ed0a64deb55616646ea98b21a891c971cd98ad.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ba145535e53899fe127987aa854f81234a9c51c4.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ba8b09f0aaa40a7c9ad5f0458b460d3e328f3c74.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bafbef3f13d429ec3e9f4672218998d5669d79f2.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bb111b7acc269f8d5e70915d3efde4c425aa5f5c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bb28a4e95723e3df380f98b5ac107c4df353850b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bb35c86443cc9ea38c06ebc0656306483c95ef67.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bba10ecb79ede07324e1198a71a95ff26e9eb235.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bbe23201fbebed25781f249e5c77c31e0e7f9ddb.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bbfd025488e52b97c04995c4c5faff371b77e4d6.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bc1ae1dddb8cc5d78196da6b26ebe66c1ce7e567.hip create 
mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bc238fd2095b26a167b41cdec8280182330b7b25.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bc4425e30a0b17e8b31726817e8d3177b5c51934.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bc4e0f0496a34d2fb43c80ce0162ad4183f29064.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bc6ce17223d8d83a64b8c96ac88223e4441a4692.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bc744db85d4237ee9640f1658e0caab7648e3bb6.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bc79e255d25744725e2a9db9f90d5cc2b8a0e0c1.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bc897852a4ca992961843144f4ec4f8b86dd5e9d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bcb6f0730fd09b4c6c60913425927dfdb8f83d82.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bcd7ccdceb7baf3b986f2a0248827822a5f72e47.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bcf8836c8cf932cc2748e313885003f0e11a887f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bd064e302ff5b983dbdb4ccf51383fb29ddff44f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bd28203f47b6a48e9b66302cf8312f3796ca500c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bd37f4f7914805a97d5073f1ebf8a8b8c2648d31.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bd3daa5f99b4522d932334924347353ce2854821.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bd6aa39d0ae3c87d011610cdb5e2e317f337c454.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bd80a1774d8b7d8bee4e8663392b97cda11dcbf5.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bd8bf7c572c1984ca3061062cf3c31d993f6762d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bd9c47f3305e47db6ab6bc627fb3d80269633074.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bdab172627718278a71a93e3737ef08ad9259a4f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bde24a8dbe6add6f2dd2beb48b1280f3a84a9b2a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_be1e1533fc37b41838bd37edc2b6d2f2e76ae1c6.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_be4dd90ccb2f258029d0156cf23f940b694cf08d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_be8ec1163a01b9cd9a802d8b44669e8770c20234.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_beae876d6da465687f162136231f15767cc7bb14.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_beb9afccc15de7dfcb2e7d898abc0d61201de73e.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bec30e7107c5dce3fe6aa87d83ed96da75478da0.hip create mode 100644 
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bec9e4c0317e8d351f60258ed6611fbf365c4024.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_becc2a4d7ac045365300bf8bd45fc6d3e1e1c8b1.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bed5a8c5cf683f6dfaefad72c2e2f5c2f2b2732f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bef3bd014a918feddadc98eed92a7734f9bcd890.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bf9cdf86a7944cd690b0fcbbaec235863acd10bb.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c0338fbc05f86270ded7df2bd3e2758a03961b62.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c0342686e4efd26413c6719782ed13603479c4e0.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c063318cb851ccaa923be12d34c84d839bc64bb8.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c08095341ca7e3a1debeb780c1878e351692bee2.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c0a3c4ac0a50bb9b7ad764929dbee98c856b1210.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c0f76aff077c28f8afd7b22f284cf2894e08a043.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c112c01d201c366bdd7acccf2e1b18b00f671153.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c11d68fe766fc753c657362673704005b538660b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c137c03bf161b2ec6a9a046fa49d7bbf80ae47b8.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c197d1f050f42d82e6851fa286db6f81ba197f40.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c1b76bc7a17f573c0d52c07ae9ff4302662ae61f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c1b94e19d762ddc33cc4e94c6675d93cbde21e3d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c1f40c3421b9ad8cf43940530ec50bcf620058f2.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c1f721a330b2d0fac13b22061616d7b10c0f91e9.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c250ea59ab6e1ee39cce15cbd3f181047cdee31a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c2541b6b5cf27de3f45f60671d36602f07ce1783.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c27b3026f1dc3056dee3a3e64bf31c45683607c9.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c28de8f96c8315877031a2d56261e95fee6aef44.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c29110dd501853e87ebc122dd1971b0bb1bcd92f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c2940fd05efd52bdf8a3f9aa4b78bde9b5809b34.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c2a2856bf9a81544a30d535a13554e3a8107c476.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c2b719893a4d8a1e71857966d399f06c0a41749c.hip create 
mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c2f04447e6a94c94a2315454e71d7d607a9fd0f8.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c2fcced07cc194a8050bc7b2f791453b3f5b2064.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c323a4d1f24d59bddd20ed2f2fb6446627b0ae8b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c355189ade9b1a8269230232db754a3881b53168.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c35ea54eb6cd0f3756c462c66d9be956279b46ad.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c363ee1b087f6b504a3dd3972b96e77db02b0582.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c3cfaf0d53869c373f6d0ec821b008dbb819141a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c3d0eaf9399c863d672e8c08d123739bab837d4b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c4015f0d0a7a5173810f6f17c00065e03fc61a89.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c402e84359b2037a29efd1d6ce7213ba7605ab25.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c41b6eda4f250da059fe0c428428219ff5a250ef.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c42ab428503e8f8bfa78c8cb8d9afad9f5185118.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c4376ac8d82db1bc25fa273a80dfbf8b71ee5e2b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c45a5e40f6a66bc5292a56e0097c69fe37cedfb3.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c487a1a9933239270f44b1e08e1cf5323521c089.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c4997f79435cf64add10506acb97d0647cfbb3d4.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c4b34d3cb673447773f6da23e9cf52b98e99f718.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c4c3425fe683d35dc3335db77d183ad1620b7a92.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c4c6c405cefe204824e8fad1b3dd34bba87e796a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c4de1bc135191f3c2aff740f4c6bb7e98da42f84.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c4dec99707511cebd9188d216ee0a148d729b470.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c538dc4f65d02776875627cbd20a9c794d70b043.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c53e295b68e807774ed31bb914e4bc59312a77d7.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c56aa150611b0d4800470c1493dc907082a5c23f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c581974c8b6f43f60d0af29c350d850b55c03121.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c59937be2b9a13d6520fdcc922e4e75c9fa085ab.hip create mode 100644 
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c59a22c6efd8bb8815887325aa0b739e260cc754.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c59ab718fa23f24f09a713ac28a339208a7a5802.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c5b440ca9a5196ee1e72c878c87d96934e9273c8.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c5fcdea177734366d3bf283317a65cc3fffda611.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c5fef330a975002ed15670e8e7b26a10376d3cb7.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c64f4cdce32189065362a502105c31bd2d9d99a4.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c6e2da8b791d31f4ba05ef5f833fd6dea9e35f1c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c7568e11e44ce70924d27e683190422cfae5c31d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c7af2bbfac25de2853be344b9f636226c1c0112d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c806d7803d06ef8aac1d5caac9f36aafd47653d5.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c80dce1a17d073259250ec0c87ade69e639ffa8e.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c8dbfaffc8a9b573f194f9c63f1175d9725f8950.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c8f6461673882d636772ae4d26e78eabcb568f31.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c919b8ed877d4244d01a17ecb948b459e361ff24.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c921a4790f982d48bcaf950123c699647afb739b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c9312d7159369d13f3148a6f0882dfad6921ceec.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c9530e20038eb40c49bc8b045be0cf4e7e6b4eac.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c977735a36c325706bd19a12df66ed0839b032b1.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c9ad71883a19b522486706d3705700c012a6fc19.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c9ba0a3369d4e4eaea1c902a90e6501f232dd57c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c9f1e7e478a2208c4d32e2d7e6abebdc16bcc5fe.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c9f28230817c9d9805c41dfcd4e834fe302e1df1.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c9fb8343e623e46f01893a2b61345d1ca5928671.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c9fe51f982abd60e567d4238d3266fb60e45814b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ca00cfdc5592b7440d72482a18781e9cf3afb05a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ca1992a2634cd6674076611be54197c715ad8271.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ca3975efd767ddf7c12e308d948bdcaf0968493a.hip create 
mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ca3d98ff43fbb80ceb82fc22ab039bee898969b0.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ca4c6ad28aff1976c6dd36974ec3b339aa3090e9.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ca5681d4e5871aacef74bdba9e368445875252d3.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ca920c3239bb5796b1ab2fc75177eb3b820aa784.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cabb7b12cdd9b8b522af577e13232b2459dbd38d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cae6c7efbfc831e2bcfc8c1efa1a486c02627cbf.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_caede7a18f3e3d5e24f6c70392413a2cda16ac15.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cb10303a0b79f2710eb7c66896d3c1f8b12c04dd.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cb1a0ce432c27f4cfa51731c3ef181bf60c8a727.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cb1b91c16e0255fe7a0a85638b98d94634e143a9.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cb1deea4f4fab0db31d46a91228601f0c272d6e6.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cb20538073888bdb3174a8e9c32d7449072aa753.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cb3d5273945c5d40cc05c2660af2df1fb7a15f3c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cb4576e8ea5d59d7663f3760009a00a19e1b0667.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cbd571f4fe576fdb17d5f75a558cb6747087c7f2.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cbe5a98163e878c7697e554758ebd0597c2c1760.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cbf3e4d4d4837a0cb33b78c4f2767b1d93da0850.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cc127a63d56099e08125b16939dac82f0173122b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cc4ac5a18f57f2ebb65f7e356e858ab0d59b2133.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cc54b107e1b557ea36b5cbaf7fe3dfce05415c86.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ccac6c0e61b65c9422c7f30fbd979031698370a9.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ccd0b777df1328bf24e070ed4cdf8615bb2199fe.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cd0453a5c3828c1358360f31f5d3b7258e17fdb9.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cd4efcdd12184211c74e7b3f2f30fecf1041ca32.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cd757a8bbeabd16a44d149ab188430f6d79ddcaf.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cde0582e1aef74f9209de638b553ec0671476258.hip create mode 100644 
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ce4714e4f33340859c106a3129993e22652262e2.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ce5064e27ba427cb951f7e1b01328b0beb6b2b7c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ce5ad502dd40353312d561e9f40aa478c16ef5b1.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ce5b5932f6df9a194ceb0d69220fba9596528eec.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ce5c161b725becf059fb4439c668edd454ac77d1.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ce909cb5f96a4884caa0d2eb8c5e6bc7fa352797.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ceb9544e2a0caae2c9e3dd8bbd2c509e8dca1379.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cee81ab2e2678816c7b516d2d4c50e8cb5874c68.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cf5c6c0bfaf98f6e655fc443246b81fcc730fe97.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cf73e1fc0015094861ca0c1c81bacdbe0c5b8f37.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cfda56a4eb08b803332f25bda6209932d9624acc.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cfec97bdfb6fa95e057eaf5a8138853e1c0884f2.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d00f65bc99ca08eba66564d34f72f2769bff9491.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d036096f49a89730f8af7e75457c88cb8ae64165.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d049a1b8f4c1c6d37973ce38593efda1de8ce0cd.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d04dc4ed02eb42c3fe303342801ed3073a0dcb8e.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d06ba4c996570ddab77b6ff1e2a0101b638543eb.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d0863830fc5d43dc6d6400280e892bb7de2892d4.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d090b771a4f9750132f549c82a88b4ab00dce5c7.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d0b09e8513646fbb2a007544a63ec9e2b04dc4c2.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d0daa59f5dce6fc3965193ae37d8c82a3d1834e6.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d0dd0165ee91c095a19ceddf08789e3576912590.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d0de618ff3ea9f67b90f2227fb7fcc74ea34183d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d0f63cafbeb445408c884727b473667fb479675e.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d137b7b6e04e1caf43a62bd6788a75361cfa98f6.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d1840494c4fa78ff399c0399b3ad7ca3d22d4587.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d18727988e47264b42b4153dc82fc1a750f08db0.hip create 
mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d1c0dfd19a08d61586758091370acbdc6f267017.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d1c25cfc437d8bd803860e39a45b2f3b9fa48393.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d1d3eacc320104100bce46235fe656e5a8223c66.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d20d45aa85c0daa299da98c277cee826fe67bd27.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d257148f457557ea80ca56690e525db3a4b0ff55.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d25ce4b3e9cc392ceafebc7fe3bcbe05aaad4bbc.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d2d08c5470a385d0160b2c1441fd1c30fff1c17c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d2daccc4b3a0f90bff39cb4597f8b7e484613d9e.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d2dfdb42c1b380e860aa5609302f29698dd27923.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d2f4b869ff23874b6bde0aab68c419108b7e69f4.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d32c64ef01aa228277d031a74df51363f98aa2b0.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d34d6cdcd81a456125ab5e0875466c6334d8e5c8.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d34fcb56caa8f80404789fba0ffac447483a4d84.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d3784fb4c0685d7b651f4113f3c71e050881f3a5.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d3a23ded424200d0c6f06b1dbd0a7b7b0e7b5d9b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d3a2edf232786d458e2125f8dfeda8847f842afa.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d3af8763f289dace1054bdcb4dfeda28b0aefcae.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d3fce1e11aee2273620e75efe4aa0390fcde9ba5.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d40569ae9dbd693c0ab3d6ba69704d31e451011b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d41b6a64dd181f2efa65aaed03a3d229b3566c1d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d41cd6b60a97e7071518cbd1a63abb8b910df024.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d43715cce8935439f90172d141050d78c7e76fb7.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d4605b2ad3e3753c5f255678abc1690b949c5abc.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d4645b713821371161a9925dec8a3d6c157ba1aa.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d4aff499ad527be5fe33b8e92547df57af26d40d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d4b99af9a573df50a27fccbec3fa8e350f1854eb.hip create mode 100644 
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d4c9f975891087e6eed6393629b41155deafc509.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d50ac8e8a03f8e7ec2c6e993dd39f09f465dab57.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d54ac01458df3f240e0656d82330f9de23ba9651.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d54b3731883a5f8393d60d27487f8d017aedd3f9.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d5e82799f4452e148c3e02acd6526cf30757eb52.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d5edfe3e3dc3008b928c8e6dbd50784b905f189e.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d600779c17b7b21c18e1308e6d765fe02a7945d3.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d6149eea92f2c40c11de3b778102fcf9b6a006b8.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d623b36cc3f56d1001b2d3abadd8a5628fefd014.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d63c8c746055851217a514321cd735eaf6937263.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d64b8b52f4a98801e185e2f132b2f80c29dd0c37.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d66b79c4ebdcfd239cecec58203606bc123bd6bb.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d66c30148a6fa816937f2f095802264d3dfa0273.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d703eea8075cacec4d41fee7dc4734f593ee79e8.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d712f23ef88ae5d7b161d36f42d22a5ba53b6354.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d713fe25dc90b3511fc259cebf463376dcb55d84.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d7145383e39dec0e346b5094401acf85ef3c2075.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d723b191785c97d284675f700a7baeb52a2eb791.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d7290cc4c3036c9205e689cbcc60e7d16b97a7d6.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d733f4c03e338ea7c6d8f759c1132499bdcea059.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d773df9ccfc1ace90fe3afb5c00976deabedf6f8.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d7adde8780b39f1364c572a19c3bfb19417678e3.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d7bda8157fb27d544e049fd7d2ec735725f1bf44.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d7fae2c18645d36a181a0bdd2d8ca7a4ac0f6d1d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d82773721479613ad72e334510a248f1436b38d6.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d867098db97b3f26e71a151c63b74260bfab21f8.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d86e4dcbe9c4cac8f7c8c5d97ce384ae0cbdbfbc.hip create 
mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d8901a63986cc28ef24cab012b32114851a8c1ec.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d9061c204d8a85c974676f4438994a0be9d69a60.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d924ee32b178b6bffa7a71603d6e2818f66177a5.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d937609afa8e21a761dad6b01ff3f26346e450fc.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d95835bc6f000d3a3379bbc38d90e83dcaf867ee.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d992eab7de49033f5480c5e86a69e675db0d2a19.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d9c23b7f8fcc4e4f4c81f5f00cfd345b98df2e0f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d9c3e27b522320dcca5ee84fa534b03aae2bfea9.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_da07d8b5666423da30a95e3b2cabd3839d200981.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_da29a515d14dac02066bcd4701285b9916b43cf5.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_da6afccdee4107507a64323e17bf12c46da2b92a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_da74887afedbd67928fe4d596709f9ff92530611.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_da822ea727fb3543e445e4000f7e6ebb946d6a3b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_da9f6e1d59132fe96709490af25bd794f267851c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_db0d0cf55d90b3f3c9eecada1db93c420f34b1ae.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_db5016bff9e5dc37184d2b9417eb351c7ea1c322.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_db85839ee8d464c5a81b8dad9839f5e0f4b467a8.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_db8f0bd93b352d28c5b6d78f4332026993f0bea4.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dbae1670fac6812b2d2cbad973e4b475509ea504.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dbb06b43d5d65429e23cc717448cf1fffb0cfd74.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dbc4135fce01e8731fec7a78d0cc0fdeeae28b90.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dbcea8f7b5930abf76eecefce92d0db785d2df5d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dbde2ef18e2174ebe13a6e7c8c2a6b05a6612047.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc039d422a57c159ea4dbcc867d766ff1b356a07.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc08afbff5def8bcb4e823657ce01f57c9dc77c9.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc184767d723f4995791848cdc68bd948408204f.hip create mode 100644 
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc1a7f9b1afeba6690fdc0d0d1755ea89c805573.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc34b6ef496d4e0d8fbbe10731d4a7b1c136c036.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc3d625c5ad3e871f5a727ac946df642d988b9ab.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc4d27535b9570b8f4b790470a83c1d0a9a2b6ce.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc5ba6d73f331c76e696953606c5b347b6a46f3f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc62a8db637d32e7dfdb2521cbdae6e1fbbd5fd1.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc818f3ce244743cb1dbff9aca399df90742a6d0.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc91797c1474a368e9cb056b50b4629d7736c3cb.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc9e54273c0ea2358fb573a7d918aa7b09fe07f9.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dcf815ef540060cc7ed43e1c57a28e1d080c5621.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dd10bbf37503bbc92af82bc3487989b41b20ca85.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dd11806cd2d3ef1127f676b2d98bf8fff2a1e5ab.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dd35634440edb25cb095800b882c70aaceca1dbb.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dd67d442001d2b167e70e8730abde4d4461b8569.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dd9494d9ac35eba6794a4f9120d2db9932596ef8.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dda8d021381083bc48b7fb1840729254dd8e5137.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ddcb1cfea1b0dbe50a02252cba99428fd977527e.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dde93ffe7fca311e136e42fbcd12b05c9fc7174c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ddf5339054f47d9ed6cc7f9e66ab21ce3bccf3db.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_de1ff66d2aeb47d2fdccaa4bb6b9d066b380c99e.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_de26a187c4db06115072a5132e1166b5b03368b0.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_de36bc309877917a18fd21acb30563c7e2f233c1.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_de5359f0fba3da9dfed06ddbea8fe2a33a9cf40c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_de6683d175affaa5ff261ab8503f64172d8eba8b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_de7eb562a7eff31d589e12945d80233aac202ae2.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_de85901d66dc04b1143bb6404445baf65693b781.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_deb9ec2cccab94920e40f62a1f0f094acd919d07.hip create 
mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_df0b2bcba57e77d975ec5304fc50cbd09cddf4bb.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_df4bb75ca79f805a81fbad750ad22f6d22b0d8ff.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_df4c9eb48da49a61957537270d94e56cb4e426be.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_df5b1c6758d4b8540158299dd0362297083084c2.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_df645b3888dc8d1df50c47c0d75822eebd3eb019.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_df66feebc9a0dcc508ce002c255154622875e524.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dfcd68acfca68d1acac94f493e25be0ef20f209f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e02a198f23c409b715761b702d7b0e6e5992701f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e035773419a9b3631698a3d375d829af55f7731e.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e088f0f7363804cf5403adef70828ab32d09a02a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e0966fa1ff013e477b1706928de6cb7f8587c154.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e09d9baa269dfbb30b714389d1733be51cc419b7.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e0e48d7edfe9513f24ad9fae68cac3aa940b17dd.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e10f47a44400de385ddbeb99475b717c5646fb41.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e11a3b7d4fdfed64e64f7a95dbc64eff541092d6.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e13b86fe4e153e0bfa8d1e75f3641fe32b0c5149.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e16075c3a5fcfe63ba12e854bb1fed6873f014ab.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e16edb824cecf459a8ec51b8dc74b1e06369aceb.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e1c1a31a1d8556cbe0b6ea76faacc78855108539.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e1cc934ba7baab1a2eb062df1e4ee5066e9ffbc3.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e1d85ad2c9d197f501267fe0804e6985802fbd18.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e2762543d3380185e304f84749a70db1b8d3dd8c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e28fd64c2f2b27577109a984e6ab82f5f0fcb296.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e2b629c37cf94134693ce455b8c88b72a39df7fe.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e2bf6805a489739abb77c13173d57723e9304afa.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e2c9f955f227430c6224ebc347649386be7f01eb.hip create mode 100644 
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e2deafd2f36cee29109fb824e0135407453adcfe.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e3015c5d50481547aa5754d042d9d7040cf1c7ff.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e307a1b0d5a8f94e0a0f4032f401d20b4b643523.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e334e691714f0b99773c2ac515ed82de0f387065.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e34b7e452a4db74189334697e3a240ad68085f0e.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e389d0e4442cd8304081892ddc75043e68a6398c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e465193d97d43237c22c04478ca5833011d8dc8b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e477abef05ff37ec27705eda51896e2aa3a04966.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e4d9a2396ceccdadab24602f30e9070901a76dc7.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e502730dea6987e2c038446c448aa08bdcc23113.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e514c6b4bc75d95a150104a17972abae77cb47ed.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e52e3053f30f780f346fa6b7a836ad2554cb85df.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e56757fb17f5e94a6ba1fb14540a68c36d571159.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e578ec9e09d3b78dca6b5bf0be1538657f02f319.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e5935fbda313d3518f142f43d46f56c600f69286.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e5b2bb9f8466de1ad5210e4c39ee7b8ecacdffa9.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e5b65fc519ea7cfcd19f7eddbc3acad6842ff558.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e5c5079636a4a31a849ce8a5af89d50330a74628.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e5ccd5f7ddc894b2717112cbfc766804e02b7bd1.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e618fb4e529104fc90069c8779ce5463460bd516.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e638053e01268a4c5883620fc6a9901951e2e01a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e639a1e84faa98477b05df71d363b9ff0f9b2760.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e68a9e05debd456a9975953f7b0d510e7a0f6978.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e6973d75297bd2c3432a7c88e8a9ee1c9ae693bf.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e6b53fb8d81148ff384d31a703bb4c2e7a5a33af.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e6e0ec1db1ea308e226f675e68e29b839e41b252.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e6e6b10e73733716e71ebf5a53703fb935fc5e02.hip create 
mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e7153f9a9b0b7c54ddf2debbe297efcffbb4fcfa.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e73a776ae4ba68c23acab1a5a6381684051738ab.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e75c757c67aa23cb88e1aced6fcf36b7b28391db.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e75d492ac3a6ab75648056bcf26250a4aa929cfd.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e76879f8ff4796f48ad87ff8003f4f6e6adca9a0.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e7ae1294b6dea5c8b93c2b814fa7460c4047105b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e7b2eb64b66d46359fab44333c2c484f4c9dd5de.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e7c0a99e949baa5f3a7ee2d6e84427982f82f76d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e7d37e7ee96c392fa24c02a9143438a3a7d05741.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e7de729aa50c10d8101ef504138c3769e3286753.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e83c604d1b8260958becd1c7c209745ff9151715.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e89bcea4393593313d18a4aa6dcb44cd75bc828d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e8a9427f34bbf5ddb28a39161acc36806e68f2d0.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e8d8fe5f4f8641998b8b805a20b2ca92d019ee59.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e8d9b65558398c0c10127b560807578ef117d7ed.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e907e8d1089557dfcc95a05160be5092e9119a53.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e95e3908479965856843317c8b0c42a6961dfd23.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e986d5f8d5591f3e0f1cdfad19c38c420fd93023.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e9b04e6d5527ba0b8089ba8bdd264e2d5759338b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e9b53fa68641f45baabf40b7cfb8b35a9a1b9c7f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ea077e68dbc1bed2dd20a5f4dd35e0cad6330ee4.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ea591185b1c5f521023e250a26f742984255b241.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ea62567e9ea16771d8445464c38f5a2931cb355a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ea6a6d4cc262ea838dbb83ee747112f95fa297bc.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_eab6cdc59bf216f7045f0cf5f221bb91ec415cd2.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_eac353f963c52624cf79e82cc2b2c02eed94b677.hip create mode 100644 
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_eac5952f46f4f2bf06257b00661774eeed48a323.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_eb278488b2cca114adca5e4614d86f92447f937a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ebb241b947a0adfc8e50c5d71765c14af24593ae.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ebb9abf5b09e63cbe76390bb46ff7cbefb3141f0.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ec171210efd217c07d357fcf42e5372ad7e9abab.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ec3deb1382003ac010d9bc1c59d1878d3ec7a727.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ec51d24ab5f24e003ed6751ae8ae5b327892b15a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ec7ec8d547ee9713aa3b5b667f22cdcaa8f62b2d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ec7fc24902b1ebd8f2bf8088b0ecf6de8be8362d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ec9f63a538940e5ace02ae5b5ddc01f730adac4d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_eca613eaa8471ad7da66d2f8f2b8e07f6e02b467.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ecd7dec90b3c62bf3a30bd75d3c6869529a06b01.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ece60111633db08f765b3c7cd5cd768cbd030255.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ed37ba962e0288e2840eb0925d016b5a7e3b3164.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ed6bdf67720e938d538a867548ac3579b8238169.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ede81dbc4cb208ef6e684c76ba1eb451d37fe10c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ee1a43f2210a8d1e5623411c95c33424cee5e747.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ee239db5a67c23a383590a651f0d8a0be43a13c7.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ee8e709eec7aef1fa681053c6d2969a5ff18c45c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ee974931e65d6b16b7c868d462b95dcae20b7513.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_eeb0e96b759e18cf703cfab0cda1385726f6e0a1.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_eee408cf9456ff977aa7d12345e9b2f1e60639f1.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ef2ebb4a86e7ed0001de9c5e607b66fe8877409f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ef40f0acf1885096efb840ec5600ec421c4db331.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ef5421703cbfa63a58ec02701e245d479a1fbfc1.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ef7cc2aa1ffd38298b52764a93cd1271b4d92f8d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_efaa0cb33c71cb8ca7b83dd0e7a6c7b01f6b50a9.hip create 
mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_efb9e7d9af47cdf79f15f674f8976c05f08b0ce8.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_efc6a7b25710f0626c3af534111b161e1459d2e1.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f01468c62c878295443981662e037ec5213cf7a3.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f020134822739be6fa0bb3d98e9dec79f025324a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f0209426a8e6bfeef7d8ae7b16db791888142298.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f028af9e5e3c25800dde938e991aaab4fc1d64aa.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f053c9c32518b895daaa3521827f37af78836fb8.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f069b38b26c30bc770f74c856e47eb498f5818e7.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f0cad48d9bc80d58705ea60eb2dda4baad68cedb.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f1246d1013d954a9316f4432c986d3be9459c548.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f12f1f1b679cabab04218037ef370d2c7e1fe332.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f15c41ddb04ec7f80235bb3db19198dd6b699713.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f18c74becc24a93427d9c0838784e9b6caad6e81.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f1ecc90ad7b86791a9e6f73a582aeff30f393804.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f21596e8c608a795ff971aea8e199db9e72b65d7.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f24bd5b92ce6bba640b8ec6b4e53fe35902c5572.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f24d42e820adc1a26a428d59df7ffdd7f8580176.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f24f26e45d5cf567d29fbe375fbf8abdec39186f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f25b87c435bc5d7d85d738f3fdf68947d79f5a77.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f280e1639680ac1e5830a21f921bfe2cf364ef42.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f2da112b1e07c44fc8a7f19368da203f6935049c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f30316cfe49323638f71ba688dd8ff9b2266b335.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f3193ea266f3718398bc5622f8bc7042c3527a42.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f34fdb8294257d951dcc9c4fa7ecf1192568b91b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f36aaa63ed42a578b953ebd614318d44cf44e8a3.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f395bec57c3b2e6e169134dd8d20b287d7405134.hip create mode 100644 
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f3bf7ef503bb026258b3ec3d82d3ef1443046964.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f3d0166931e4406873d8f552a5d5b61fde2391a3.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f3fd08d56f8a9be1a8dd104cdb1ac58e283b5064.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f3ff73f82aee3184849d04c2364eaa45c6d0de9c.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f42cf0e5fe479690883507028748b0cd3dc83cbb.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f4658c32d562f9d60c5ca1262a2e0df2375063bb.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f48f8b681a405bfeba5aadaef40f32367ec5cd2b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f4900c0a5c0d03dc17d7a907ab40652d9920e756.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f4a6438394dd3427f29aa0bbe58ad1f797c3c38d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f4b87f983a5e84582efa1663f84da76cf60b5f6f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f4c803838f5644ccc6f04f7c8a6233fed0b6639e.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f4df1cbfbaf67705820f125b474469ad7ebab0c0.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f50fa4ea674a590d0a817367ad9915a5fce20c51.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f51f1a11f778d99a00aa5959a3e58a41fcbfb1e3.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f525b59df454ccf53da6cb201e0aa8d09f52a2ad.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f57f84892e2a8496169b7406e63b0d4f5aa63aaf.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f5803aadd93e33567aa6b23100ce4fbb6c040dd6.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f5f1797f6b672a55476348571ce17645c8a62869.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f6566441ac3074578cfe45758ba0583c0da0a5ab.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f672bf80a78885428b2c02e522426470653a7351.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f682399cd6412fed6a1141296a7e4d42078f7b29.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f6856ca950bcf173571766c3f04de4163be0402e.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f69548d6cced86c21c09c6475237a0cb926df0ed.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f69878f4ca8cfe6b8d8748766f66a1ef8eab20ad.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f6f102a388ffb05c690a20a29cfe0b35a35eed61.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f7035f4bfd8f2f427720a07e3c311bccc1dba683.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f71f96ce4dcc7f789a8ace73c230c203b05ff6dc.hip create 
mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f727911254904ce4341e4ff5f8bafc430b8cfbbf.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f731289837f915e2aec1bd01eef1b3c1b099864d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f79def2b4edf6d18f6ef1d6b141f9e0435441f6a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f7aa9c39b06e55bf4bc9f9a2a0fb075c9d4e69ce.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f7cf08242b3fb1c643d4149bec985b667b9d28fa.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f851da732f397624717160f89271514bc334b59b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f861d8693f82d22e2c5b1abbcbae5f30f4433e5e.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f87790f260630f312b84888dcbdf849ce130ae59.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f87991cb7787a29d3ce4711b4ce04c5fb6a14ca9.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f90410c26d7649e21e2ae5e32e7af89d84d2ea70.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f92e9a82c879051d6fe3c42108f8a574187704af.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f93bc23b8a4f1e0fc5c5756c4e1c835bf59dea09.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f93bf815b520a9d9e17b43bf9d7fb870751b6225.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f974b12e83e214c30995a25631d37df1478927af.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f9824fb32933b27501ae8a7f43f460a2dda6a814.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f98a6b193fec3203eaa75819f6b51aa45a48f212.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f9c58761c927b222112cb5cb6c9acb5d3c915785.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fa16fa84278b489af253b52839786f94aeeac36f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fa62a97675719c2e8e9bb97361b92ff1c7b9d2ef.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fa85f869a92f0482605e52019828244b12e12b44.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fabdc143c29d5ca50ab1e96a814bda6d05b0d5d2.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fac5a0f98b94530befd634891e42c424bb86f0e1.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fac99c3c82b77946f6844699d2333cd532a78a26.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_faf56e45b2240515e97fc1bfd552eb03b6de5094.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_faf686067fa433cea5e95dd523846dc881eff635.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fb2fbb135d59028afcf867c2cf08edc323565528.hip create mode 100644 
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fb4c15452f9155c5966990f09432e5eb7e28e785.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fb4c5f8fecfbbe16e6648becb3b5ca89fa3d8a94.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fb5bb49928ce5515d7b297d5eadd4ec70a22d60b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fb79e1f9231692d736dbada062ed6821f34927bf.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fb9477a613665cebcad781389ba7c5a36f51efe2.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fba36678d5047ded97ee7a7ba9feb9569afdb6ea.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fba47fa8d9b5375bc408af68b67345ab9dba2eb8.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fbea85b766bf0c918ee0baf24dffc6a5563d5105.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fbeec221cd63adaedceec39db41ea942f99f5133.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fc030b61ae20c4b7d9b2d10930a17e01e9e93328.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fc1790325b59bd44b0a5f6cf9723a25fd845cba7.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fc1eb85a00017efdc610e4259d2abe935b85304f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fc5841a729099340d608e31023acbeaeade3e886.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fc5ebf0f2200f37ccc0849e0c3745f6e2f00111d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fc7b0916744b593435d8e1e7b6d874d760cd5e3b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fc86c13e933cba40553ffba31d53aad27415ce4b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fcb0b08e29b2e1bf181fceceb9dc416e54f52b00.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fcb6ef39c3db49f26f736d6c9221dd825409ec4e.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fcbe827108d252b2f5847fa8e132c9c3e56a90a0.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fccabea88b8e290688c1b360875d228e6fdf1624.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fd10a3b937e9659716925e39a01d794914b08e26.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fd19d7614f2ed5da21a52ed172ef62cc07c9c01a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fd26e43ca652e6f58ff48c356165aa4349833b55.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fd345632e0cae0d549ba79626a08b1885711deb6.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fd3558b4c7a667dbc365c4c2ceda646975408f51.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fd614df484b263deae3b3c20adb0ce7b62eaa651.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fd9cd1305633b62b68fb8474ce021f639f8492e7.hip create 
mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fde12cd366d6850ce26afce98e5076b695b4875b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fe245e9ea974adce2b9807d33b9ba12d916eaffb.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fe72cdd69944d2d765478d4aed13066a02b76f6d.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fe8b8c3525fe86a20a2d6c69585f3e36c16caabd.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fe97b7adcd67ed9bda8831d1f3f1ca7590c6d251.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fe9d98dbec5096a89b116f85675af772f023014a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_feb5e77111fe1e20bafdb83a925b5faeeb6214af.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fecd7501265b4c4dcf015485e63e2324304f70d3.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fecffa403b3631b1957e1a9a06f18fdb3b4eee5f.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ff453e3bdc9752cb7b81f7cc3056325a8b9a8ad4.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ff6862dbdbb20bc63a650e1f93e9ac169bb702b2.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ffb5b7349a671b182d73c8016590f26fe06a4cba.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ffb8adef0cef91a86f36872407fea35df90e8f2b.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ffc6056d9fe125a4dbe08c1d86354e51f7daadd5.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ffd868d49abdb769ab82c21508d655daf54b8a99.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fff7aa57cca501f221077124359a589b3a6f9d0a.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fffbfcac254e33926131a71905e93f9cc0aef89e.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_fwd.hpp create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/mask.hpp create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_bwd_ck.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_fwd_ck.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_varlen_bwd_ck.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_varlen_fwd_ck.hip create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/rename_ck_autogen_files.output.txt create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/rename_ck_autogen_files.sh create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/rotary.hpp create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/flash_common_hip.hpp diff --git a/LICENSE b/LICENSE index 9315c4efb68a..966a609b61e5 100644 --- a/LICENSE +++ b/LICENSE @@ -32,6 +32,10 @@ All contributions by Cruise LLC: Copyright (c) 2022 Cruise LLC. All rights reserved. +All contributions by Tri Dao: +Copyright (c) 2024 Tri Dao. +All rights reserved. 
+ All contributions by Arm: Copyright (c) 2021, 2023-2024 Arm Limited and/or its affiliates diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index 442ce7bbe890..9b526e0f2b86 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -168,9 +168,28 @@ file(GLOB flash_attention_cuda_cu "native/transformers/cuda/flash_attn/*.cu") file(GLOB flash_attention_cuda_kernels_cu "native/transformers/cuda/flash_attn/kernels/*.cu") file(GLOB flash_attention_cuda_cpp "native/transformers/cuda/flash_attn/*.cpp") -# flash_attention sources +# flash_attention hip sources file(GLOB flash_attention_hip_hip "native/transformers/hip/flash_attn/*.hip") -file(GLOB flash_attention_src_hip_hip "native/transformers/hip/flash_attn/src/*.hip") +# if USE_FLASH_ATTENTION is set, ensure CK instances get generated +if(USE_FLASH_ATTENTION) + if(DEFINED ENV{USE_CK_FLASH_ATTENTION}) + set(USE_CK_FLASH_ATTENTION $ENV{USE_CK_FLASH_ATTENTION}) + if(USE_CK_FLASH_ATTENTION STREQUAL "1") + if(DEFINED ENV{PYTORCH_ROCM_ARCH}) + list(LENGTH PYTORCH_ROCM_ARCH NUM_ARCHS) + if(NUM_ARCHS GREATER 1) + message(WARNING "Building CK for multiple archs can increase build time considerably! + Consider setting PYTORCH_ROCM_ARCH env var value as the gfx arch you need to build for") + endif() + endif() + message(STATUS "USE_CK_FLASH_ATTENTION is set; building PyTorch with CK Flash Attention enabled") + file(GLOB flash_attention_hip_ck_hip "native/transformers/hip/flash_attn/ck/*.hip") + list(APPEND native_transformers_hip_hip ${flash_attention_hip_ck_hip}) + endif() + endif() + file(GLOB flash_attention_hip_aot_hip "native/transformers/hip/flash_attn/aot/*.hip") + file(GLOB flash_attention_src_hip_hip "native/transformers/hip/flash_attn/src/*.hip") +endif() #Mem_eff attention sources file(GLOB mem_eff_attention_cuda_cu "native/transformers/cuda/mem_eff_attention/*.cu") @@ -185,6 +204,7 @@ if(USE_FLASH_ATTENTION) list(APPEND ATen_ATTENTION_KERNEL_SRCS ${flash_attention_cuda_kernels_cu}) list(APPEND native_transformers_hip_hip ${flash_attention_hip_hip}) + list(APPEND native_transformers_hip_hip ${flash_attention_hip_aot_hip}) list(APPEND native_transformers_src_hip_hip ${flash_attention_src_hip_hip}) endif() @@ -325,6 +345,9 @@ if(USE_ROCM) # Next two lines are needed because TunableOp uses third-party/fmt list(APPEND ATen_HIP_INCLUDE $) list(APPEND ATen_HIP_DEPENDENCY_LIBS fmt::fmt-header-only) +if(USE_FLASH_ATTENTION) + list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/native/transformers/hip/flash_attn/ck) +endif() list(APPEND ATen_HIP_SRCS ${ATen_HIP_SRCS} ${hip_hip} diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp index 3eb7c9375387..ee9c762fdb9b 100644 --- a/aten/src/ATen/Context.cpp +++ b/aten/src/ATen/Context.cpp @@ -343,6 +343,40 @@ void Context::setBlasPreferredBackend(at::BlasBackend b) { #endif } +at::ROCmFABackend Context::getROCmFAPreferredBackend() const { + return rocm_fa_preferred_backend; +} + +void Context::setROCmFAPreferredBackend(at::ROCmFABackend b) { + + // TODO: add plumbing for hasCK for validity checking + TORCH_CHECK((b != at::ROCmFABackend::Ck) || hasROCM(), + "Cannot set preferred flash attention backend to Ck if PyTorch has not been compiled for ROCm."); +#ifdef USE_ROCM + if(b == at::ROCmFABackend::Ck) { + static const bool ck_unsupported = []() { + static const std::vector archs = { + "gfx90a", "gfx942" + }; + for (auto index: c10::irange(getNumGPUs())) { + if (!detail::getCUDAHooks().isGPUArch(index, archs)) { + TORCH_WARN_ONCE( + "Attempting 
to use CK on an unsupported architecture! Cannot set backend to CK"); + return true; + } + } + return false; + }(); + if(!ck_unsupported) rocm_fa_preferred_backend = b; + } + else { + rocm_fa_preferred_backend = b; + } +#endif + rocm_fa_preferred_backend = b; +} + + bool Context::allowFP16ReductionCuBLAS() const { return allow_fp16_reduction_cublas; } diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h index c5443c56a9ce..87f53c5f197c 100644 --- a/aten/src/ATen/Context.h +++ b/aten/src/ATen/Context.h @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -239,6 +240,9 @@ class TORCH_API Context { at::BlasBackend blasPreferredBackend(); void setBlasPreferredBackend(at::BlasBackend); + at::ROCmFABackend getROCmFAPreferredBackend() const; + void setROCmFAPreferredBackend(at::ROCmFABackend); + // Note [Enabling Deterministic Operations] // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // Operations in PyTorch that normally act nondeterministically, but have an @@ -428,6 +432,10 @@ class TORCH_API Context { #endif ? at::BlasBackend::Cublaslt : at::BlasBackend::Cublas; + at::ROCmFABackend rocm_fa_preferred_backend = + c10::utils::check_env("TORCH_ROCM_FA_PREFER_CK") == true + ? at::ROCmFABackend::Ck + : at::ROCmFABackend::Default; #ifdef C10_MOBILE bool release_original_weights = true; #else diff --git a/aten/src/ATen/ROCmFABackend.h b/aten/src/ATen/ROCmFABackend.h new file mode 100644 index 000000000000..6e2844cc8be1 --- /dev/null +++ b/aten/src/ATen/ROCmFABackend.h @@ -0,0 +1,31 @@ +#pragma once + +#include + +#include +#include + +namespace at { + +enum class ROCmFABackend : int8_t { Default, AOTriton, Ck }; + +inline std::string ROCmFABackendToString(at::ROCmFABackend backend) { + switch (backend) { + case ROCmFABackend::Default: + return "at::ROCmFABackend::Default"; + case ROCmFABackend::AOTriton: + return "at::ROCmFABackend::AOTriton"; + case ROCmFABackend::Ck: + return "at::ROCmFABackend::Ck"; + default: + TORCH_CHECK(false, "Unknown ROCm flash attention backend") + } +} + +inline std::ostream& operator<<( + std::ostream& stream, + at::ROCmFABackend backend) { + return stream << ROCmFABackendToString(backend); +} + +} // namespace at diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp index 80eb89600da2..c83889cd4bec 100644 --- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp +++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp @@ -28,7 +28,7 @@ #if USE_ROCM #if defined(USE_FLASH_ATTENTION) || defined(USE_MEM_EFF_ATTENTION) #include -#define USE_AOTRITON 1 +#define USE_ROCM_ATTENTION 1 #endif #endif @@ -219,15 +219,21 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug using sm80 = SMVersion<8, 0>; using sm90 = SMVersion<9, 0>; #if USE_ROCM -#if USE_AOTRITON - auto stream = at::cuda::getCurrentCUDAStream().stream(); - if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) { - auto dprops = at::cuda::getCurrentDeviceProperties(); - if (debug) { - TORCH_WARN( - "Flash attention was not compiled for current AMD GPU architecture. Attempting to run on architecture ", dprops->gcnArchName); - } - return false; +#if USE_ROCM_ATTENTION + if(at::globalContext().getROCmFAPreferredBackend() == at::ROCmFABackend::Ck) { + // User explicitly set CK as the flash attention backend. 
Return true for now + // TODO: Flesh out sanity checks + return true; + } else { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) { + auto dprops = at::cuda::getCurrentDeviceProperties(); + if (debug) { + TORCH_WARN( + "Flash attention was not compiled for current AMD GPU architecture. Attempting to run on architecture ", dprops->gcnArchName); + } + return false; + } } #else return false; @@ -254,7 +260,7 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug) using sm50 = SMVersion<5, 0>; using sm90 = SMVersion<9, 0>; #if USE_ROCM -#if USE_AOTRITON +#if USE_ROCM_ATTENTION auto stream = at::cuda::getCurrentCUDAStream().stream(); if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) { auto dprops = at::cuda::getCurrentDeviceProperties(); diff --git a/aten/src/ATen/native/transformers/hip/aotriton_adapter.h b/aten/src/ATen/native/transformers/hip/aotriton_adapter.h index 1709bf4d0595..11f83ffef368 100644 --- a/aten/src/ATen/native/transformers/hip/aotriton_adapter.h +++ b/aten/src/ATen/native/transformers/hip/aotriton_adapter.h @@ -124,7 +124,7 @@ inline aotriton::TensorView<0> mk_aoscalartensor(const at::Tensor& q) inline aotriton::TensorView<0> mk_philoxtensor(const int64_t* ptr) { return aotriton::TensorView<0>(reinterpret_cast(ptr), - aotriton::DType::kUInt64); // AOTriton excepts unsigned int64 + aotriton::DType::kUInt64); // AOTriton accepts unsigned int64 } } // namespace aotriton_adapter diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip b/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip similarity index 96% rename from aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip rename to aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip index dcbac79e317d..598105ecef18 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip +++ b/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip @@ -115,24 +115,18 @@ prepare_philox_arguments(float p_dropout, int64_t counter_offset) { #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") std::tuple -mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size - const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size - const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size - std::optional &out_, // batch_size x seqlen_q x num_heads x head_size - std::optional &alibi_slopes_, // num_heads or batch_size x num_heads - const float p_dropout, - const float softmax_scale, - bool is_causal, - int window_size_left, - int window_size_right, - const bool return_softmax, - std::optional gen_) { - // Otherwise the kernel will be launched from cuda:0 device - // Cast to char to avoid compiler warning about narrowing - // [ROCM specific]: must be at the beginning of the function - // Otherwise check_gpu_arch() checks cuda:0 device. 
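
The sdp_utils.cpp hunk above short-circuits the hardware check whenever the user has selected CK, and otherwise keeps the existing AOTriton arch probe. Purely as an illustrative sketch (not part of the patch), this is how calling code can key off the same Context switch; the two launch stubs are hypothetical stand-ins for the real CK and AOTriton entry points and exist only to keep the snippet self-contained (assumes linking against libtorch):

    #include <iostream>
    #include <ATen/Context.h>
    #include <ATen/ROCmFABackend.h>

    // Hypothetical stand-ins for the real kernels (ck/mha_fwd_ck.hip and the
    // renamed aot/mha_all_aot.hip entry points); used only to make this build.
    static void launch_ck_forward()  { std::cout << "dispatching to CK flash attention\n"; }
    static void launch_aot_forward() { std::cout << "dispatching to AOTriton flash attention\n"; }

    // Mirrors the decision made in check_flash_attention_hardware_support():
    // consult the ROCm flash-attention preference stored in the global Context.
    void launch_flash_forward() {
      if (at::globalContext().getROCmFAPreferredBackend() == at::ROCmFABackend::Ck) {
        launch_ck_forward();
      } else {
        launch_aot_forward();
      }
    }
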
- at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()}; - +mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size + std::optional &out_, // batch_size x seqlen_q x num_heads x head_size + std::optional &alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, + const float softmax_scale, + bool is_causal, + int window_size_left, + int window_size_right, + const bool return_softmax, + std::optional gen_) { auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream(); check_gpu_arch(stream); @@ -242,7 +236,7 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head } std::tuple -mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i +mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i std::optional &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i @@ -408,7 +402,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q } std::tuple -mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og +mha_bwd_aot(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size @@ -559,7 +553,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si } std::tuple -mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size +mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i @@ -747,7 +741,6 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size return { dq, dk, dv, softmax_d }; } - -} // namespace pytorch_fmha +} // namespace pytorch_flash #endif diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/bias.hpp b/aten/src/ATen/native/transformers/hip/flash_attn/ck/bias.hpp new file mode 100644 index 000000000000..8115288fb887 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/bias.hpp @@ -0,0 +1,100 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
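
For completeness, the ROCmFABackend.h header added earlier in this diff also provides a string helper and a stream operator that the renamed mha_*_aot / CK paths can use for logging. A small illustrative snippet, not part of the patch (assumes linking against libtorch):

    #include <iostream>
    #include <ATen/Context.h>
    #include <ATen/ROCmFABackend.h>

    int main() {
      // Read the currently preferred ROCm flash-attention backend ...
      at::ROCmFABackend backend = at::globalContext().getROCmFAPreferredBackend();

      // ... and print it two ways: via the explicit helper and via operator<<.
      std::cout << at::ROCmFABackendToString(backend) << "\n";  // e.g. "at::ROCmFABackend::Default"
      std::cout << backend << "\n";                             // same text, streamed directly
      return 0;
    }
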
+ +#pragma once + +#include +#include +#include +#include + +// keep sync with BlockAttentionBiasEnum +enum class bias_enum +{ + no_bias = 0, + elementwise_bias = 1, + alibi = 2, +}; + +struct bias_info +{ + bias_enum type; + /* + * simple dispatch logic + * + * if type == elementwise_bias: + * if rank_info == 0: + * bias is 1*1*s*s + * elif rank_info == 1: + * bias is 1*h*s*s + * elif rank_info == 2: + * bias is b*h*s*s + * + * elif type == alibi: + * if rank_info == 0: + * alibi in 1*h + * elif rank_info == 1: + * alibi in b*h + */ + int rank_info; + + void serialize(std::ostream& os) const + { + if(type == bias_enum::no_bias) + os << "n"; + else if(type == bias_enum::elementwise_bias) + { + os << "e"; + if(rank_info != 0) + { + os << "[" << rank_info << "]"; + } + } + else if(type == bias_enum::alibi) + { + os << "alibi"; + if(rank_info != 0) + { + os << "[" << rank_info << "]"; + } + } + } + + static bias_info decode(std::string str) + { + bias_info info{bias_enum::no_bias, 0}; + if(str == "0" || str == "n") + { + info.type = bias_enum::no_bias; + } + else if(str.compare(0, 1, "1") == 0 || str.compare(0, 1, "e") == 0 || + str.compare(0, 11, "elementwise") == 0) + { + info.type = bias_enum::elementwise_bias; + auto found_0 = str.find(':'); + if(found_0 != std::string::npos) + { + std::string e = str.substr(found_0 + 1); + info.rank_info = atoi(e.c_str()); + } + } + else if(str.compare(0, 1, "2") == 0 || str.compare(0, 1, "a") == 0 || + str.compare(0, 5, "alibi") == 0) + { + info.type = bias_enum::alibi; + auto found_0 = str.find(':'); + if(found_0 != std::string::npos) + { + std::string e = str.substr(found_0 + 1); + info.rank_info = atoi(e.c_str()); + } + } + return info; + } + + friend std::ostream& operator<<(std::ostream& os, const bias_info& bi) + { + bi.serialize(os); + return os; + } +}; diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd.hpp b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd.hpp new file mode 100644 index 000000000000..2f21bc13622a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd.hpp @@ -0,0 +1,447 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
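
bias.hpp above encodes the bias/alibi configuration as a small tagged struct with decode()/serialize() string round-tripping. A minimal usage sketch, not part of the patch (the include path is illustrative; in-tree it lives under native/transformers/hip/flash_attn/ck/):

    #include <iostream>
    #include "bias.hpp"  // illustrative include path

    int main() {
      // decode() accepts either the numeric or the textual spelling,
      // with an optional ":rank" suffix for the rank_info field.
      bias_info a = bias_info::decode("n");        // no bias
      bias_info b = bias_info::decode("e:1");      // elementwise bias, rank_info = 1
      bias_info c = bias_info::decode("alibi:2");  // alibi slopes, rank_info = 2

      // operator<< forwards to serialize(); prints: n e[1] alibi[2]
      std::cout << a << " " << b << " " << c << "\n";
      return 0;
    }
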
+ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +template <typename DataType> +struct FmhaBwdTypeConfig; + +template <> +struct FmhaBwdTypeConfig<ck_tile::half_t> +{ + using QDataType = ck_tile::half_t; + using KDataType = ck_tile::half_t; + using VDataType = ck_tile::half_t; + using GemmDataType = ck_tile::half_t; + using BiasDataType = ck_tile::half_t; + using LSEDataType = float; + using AccDataType = float; // data type for gemm accumulation + using DDataType = float; + using RandValOutputDataType = uint8_t; + using ODataType = ck_tile::half_t; + using OGradDataType = ck_tile::half_t; + using QGradDataType = ck_tile::half_t; + using KGradDataType = ck_tile::half_t; + using VGradDataType = ck_tile::half_t; + using BiasGradDataType = ck_tile::half_t; +}; + +template <> +struct FmhaBwdTypeConfig<ck_tile::bf16_t> +{ + using QDataType = ck_tile::bf16_t; + using KDataType = ck_tile::bf16_t; + using VDataType = ck_tile::bf16_t; + using GemmDataType = ck_tile::bf16_t; + using BiasDataType = ck_tile::bf16_t; + using LSEDataType = float; + using AccDataType = float; // data type for gemm accumulation + using DDataType = float; + using RandValOutputDataType = uint8_t; + using ODataType = ck_tile::bf16_t; + using OGradDataType = ck_tile::bf16_t; + using QGradDataType = ck_tile::bf16_t; + using KGradDataType = ck_tile::bf16_t; + using VGradDataType = ck_tile::bf16_t; + using BiasGradDataType = ck_tile::bf16_t; +}; + +struct FmhaMasks +{ + using NoMask = ck_tile::GenericAttentionMask<false>; + using GenericMask = ck_tile::GenericAttentionMask<true, true>; + using CausalMask = ck_tile::GenericAttentionMask<true, false>; +}; + +// runtime args; some are passed to the kargs, some are used to compute grids/blocks +struct fmha_bwd_args +{ + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + const void* bias_ptr; // bias or alibi_slope pointer + const void* o_ptr; + const void* lse_ptr; + const void* do_ptr; + void* d_ptr; + void* rand_val_ptr; + void* dq_ptr; + void* dk_ptr; + void* dv_ptr; + void* dbias_ptr; + void* dq_acc_ptr; + const void* seqstart_q_ptr; + const void* seqstart_k_ptr; + const void* seqlen_k_ptr; + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t batch; + ck_tile::index_t max_seqlen_q; + ck_tile::index_t max_seqlen_k; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + ck_tile::index_t nhead_q; + ck_tile::index_t nhead_k; + float scale; + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_bias; // if alibi: set to h for b*h slopes, 0 for 1*h slopes + ck_tile::index_t stride_o; + ck_tile::index_t stride_randval; + ck_tile::index_t stride_do; + ck_tile::index_t stride_dq_acc; + ck_tile::index_t stride_dq; + ck_tile::index_t stride_dk; + ck_tile::index_t stride_dv; + ck_tile::index_t stride_dbias; + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_bias; + ck_tile::index_t nhead_stride_o; + ck_tile::index_t nhead_stride_randval; + ck_tile::index_t nhead_stride_do; + ck_tile::index_t nhead_stride_lsed; + ck_tile::index_t nhead_stride_dq_acc; + ck_tile::index_t nhead_stride_dq; + ck_tile::index_t nhead_stride_dk; + ck_tile::index_t nhead_stride_dv; + ck_tile::index_t nhead_stride_dbias; + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_bias; + ck_tile::index_t batch_stride_o; + ck_tile::index_t batch_stride_randval; + ck_tile::index_t
batch_stride_do; + ck_tile::index_t batch_stride_lsed; + ck_tile::index_t batch_stride_dq_acc; + ck_tile::index_t batch_stride_dq; + ck_tile::index_t batch_stride_dk; + ck_tile::index_t batch_stride_dv; + ck_tile::index_t batch_stride_dbias; + ck_tile::index_t split_stride_dq_acc; + ck_tile::index_t window_size_left; + ck_tile::index_t window_size_right; + ck_tile::index_t mask_type; + float p_drop; + float p_undrop; + std::variant, std::pair> + drop_seed_offset; +}; + +template +auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) +{ + assert(args.nhead_q % args.nhead_k == 0); + auto kargs = [&] { + // create group mode kernel arguments + if constexpr(FmhaBwdDQDKDVKernel::kIsGroupMode) + { + return FmhaBwdDQDKDVKernel::MakeKargsImpl(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.lse_ptr, + args.do_ptr, + args.d_ptr, + args.rand_val_ptr, + args.dk_ptr, + args.dv_ptr, + args.dbias_ptr, + args.dq_acc_ptr, + args.seqstart_q_ptr, + args.seqstart_k_ptr, + args.seqlen_k_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_do, + args.stride_dq_acc, + args.stride_dk, + args.stride_dv, + args.stride_dbias, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_do, + args.nhead_stride_lsed, + args.nhead_stride_dq_acc, + args.nhead_stride_dk, + args.nhead_stride_dv, + args.nhead_stride_dbias, + args.split_stride_dq_acc, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.drop_seed_offset); + } + else + { // create batch mode kernel arguments + return FmhaBwdDQDKDVKernel::MakeKargsImpl(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.lse_ptr, + args.do_ptr, + args.d_ptr, + args.rand_val_ptr, + args.dk_ptr, + args.dv_ptr, + args.dbias_ptr, + args.dq_acc_ptr, + args.seqlen_q, + args.seqlen_k, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_do, + args.stride_dq_acc, + args.stride_dk, + args.stride_dv, + args.stride_dbias, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_do, + args.nhead_stride_lsed, + args.nhead_stride_dq_acc, + args.nhead_stride_dk, + args.nhead_stride_dv, + args.nhead_stride_dbias, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + args.batch_stride_bias, + args.batch_stride_randval, + args.batch_stride_do, + args.batch_stride_lsed, + args.batch_stride_dq_acc, + args.batch_stride_dk, + args.batch_stride_dv, + args.batch_stride_dbias, + args.split_stride_dq_acc, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.drop_seed_offset); + } + }(); + + dim3 grids = FmhaBwdDQDKDVKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_k); + return ck_tile::make_tuple(kargs, grids); +} + +template +auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args) +{ + auto kargs = [&] { + // create group mode kernel arguments + if constexpr(FmhaBwdOGradDotOKernel::kIsGroupMode) + { + return FmhaBwdOGradDotOKernel::MakeKargs(args.o_ptr, + args.do_ptr, + args.d_ptr, + args.p_undrop, + args.seqstart_q_ptr, + args.hdim_v, + args.stride_do, + args.stride_o, + args.nhead_stride_do, 
+ args.nhead_stride_o, + args.nhead_stride_lsed); + } + else + { // create batch mode kernel arguments + return FmhaBwdOGradDotOKernel::MakeKargs(args.o_ptr, + args.do_ptr, + args.d_ptr, + args.p_undrop, + args.seqlen_q, + args.hdim_v, + args.stride_do, + args.stride_o, + args.nhead_stride_do, + args.nhead_stride_o, + args.nhead_stride_lsed, + args.batch_stride_do, + args.batch_stride_o, + args.batch_stride_lsed); + } + }(); + + dim3 grids = FmhaBwdOGradDotOKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q); + return ck_tile::make_tuple(kargs, grids); +} + +template +auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args) +{ + auto kargs = [&] { + // create group mode kernel arguments + if constexpr(FmhaBwdConvertQGradKernel::kIsGroupMode) + { + return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr, + args.dq_ptr, + args.seqstart_q_ptr, + args.seqstart_k_ptr, + args.hdim_q, + args.stride_dq, + args.stride_dq_acc, + args.nhead_stride_dq, + args.nhead_stride_dq_acc, + args.split_stride_dq_acc); + } + else + { // create batch mode kernel arguments + return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr, + args.dq_ptr, + args.seqlen_q, + args.seqlen_k, + args.hdim_q, + args.stride_dq, + args.stride_dq_acc, + args.nhead_stride_dq, + args.nhead_stride_dq_acc, + args.batch_stride_dq, + args.batch_stride_dq_acc, + args.split_stride_dq_acc); + } + }(); + + dim3 grids = FmhaBwdConvertQGradKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q); + return ck_tile::make_tuple(kargs, grids); +} + +// this is used to pattern-match internl kernel implementation, not to instantiate kernel +template +struct fmha_bwd_dq_dk_dv_traits_ +{ + static constexpr ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr auto FmhaBwdPipelineEnum = FmhaBwdPipelineEnum_; + using FmhaMask = ck_tile::remove_cvref_t; + using FmhaDropout = ck_tile::remove_cvref_t; + static constexpr auto BiasEnum = BiasEnum_; + static constexpr bool kHasBiasGrad = kHasBiasGrad_; + static constexpr bool kPadS = kPadS_; + static constexpr bool kPadSK = kPadSK_; + static constexpr bool kPadD = kPadD_; + static constexpr bool kPadDv = kPadDv_; + static constexpr bool kIsDeterministic = kIsDeterministic_; +}; + +template +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config&, fmha_bwd_args); + +template +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config&, fmha_bwd_args); + +template +std::string fmha_bwd_dq_dk_dv_get_name_(); + +template +struct fmha_bwd_dot_do_o_traits_ +{ + static constexpr ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr bool kPadS = kPadS_; + static constexpr bool kPadDv = kPadDv_; +}; + +template +float fmha_bwd_dot_do_o_(const ck_tile::stream_config&, fmha_bwd_args); + +template +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config&, fmha_bwd_args); + +template +std::string fmha_bwd_dot_do_o_get_name_(); + +template +struct fmha_bwd_convert_dq_traits_ +{ + static constexpr ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr bool kPadS = kPadS_; + static constexpr bool kPadD = kPadD_; + static constexpr bool kIsDeterministic = kIsDeterministic_; +}; + +template +float fmha_bwd_convert_dq_(const ck_tile::stream_config&, fmha_bwd_args); + +template +void 
fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config&, fmha_bwd_args); + +template +std::string fmha_bwd_convert_dq_get_name_(); + +// This is the public API, will be generated by script +struct fmha_bwd_traits +{ + int hdim_q; + int hdim_v; + std::string data_type; + bool is_group_mode; + mask_enum mask_type; + bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum + bool has_dbias; + bool has_dropout; + bool is_store_randval; + bool is_deterministic; + // TODO: padding check is inside this api +}; +float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&); diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_00042c36bc588e60a7c8a9ba297a8a25d8ac0660.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_00042c36bc588e60a7c8a9ba297a8a25d8ac0660.hip new file mode 100644 index 000000000000..11ff05e5dbdc --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_00042c36bc588e60a7c8a9ba297a8a25d8ac0660.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0029076f83a3dc695a167beda6fe19230a2b114b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0029076f83a3dc695a167beda6fe19230a2b114b.hip new file mode 100644 index 000000000000..e50f11e48b23 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0029076f83a3dc695a167beda6fe19230a2b114b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_006c417a52a1bd7c55e45d111483d26f4480caeb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_006c417a52a1bd7c55e45d111483d26f4480caeb.hip new file mode 100644 index 000000000000..f4235b476ee6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_006c417a52a1bd7c55e45d111483d26f4480caeb.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_008f2429c678d13386a06e8d8b15c4b480940ff3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_008f2429c678d13386a06e8d8b15c4b480940ff3.hip new file mode 100644 index 000000000000..1fcb150bb0c6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_008f2429c678d13386a06e8d8b15c4b480940ff3.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 128, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128, + ck_tile::bf16_t, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_00a2adbe938d458d51ca5fc4020667a215b672a4.hip 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_00a2adbe938d458d51ca5fc4020667a215b672a4.hip new file mode 100644 index 000000000000..3da78aabd73a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_00a2adbe938d458d51ca5fc4020667a215b672a4.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + 
+template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_012c0f480917c329f4c3c6c666cf32af2d82b294.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_012c0f480917c329f4c3c6c666cf32af2d82b294.hip new file mode 100644 index 000000000000..a5ee51bba50e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_012c0f480917c329f4c3c6c666cf32af2d82b294.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_014c209d5cfc6b965bfd78c64bf132c0154e32be.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_014c209d5cfc6b965bfd78c64bf132c0154e32be.hip new file mode 100644 index 000000000000..d23c280ce1c3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_014c209d5cfc6b965bfd78c64bf132c0154e32be.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0153ec18d3ded0f8bdc6459ea5757ebd94d9faf2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0153ec18d3ded0f8bdc6459ea5757ebd94d9faf2.hip new file mode 100644 index 000000000000..b8272029ad30 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0153ec18d3ded0f8bdc6459ea5757ebd94d9faf2.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 128, + true, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128, + ck_tile::bf16_t, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01ac1a2ecf9a487809e46faa92e267df2d47de91.hip 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01ac1a2ecf9a487809e46faa92e267df2d47de91.hip new file mode 100644 index 000000000000..c7cb12d5cd1a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01ac1a2ecf9a487809e46faa92e267df2d47de91.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 64, + 256, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<256, + ck_tile::fp16_t, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01ca79005067e20e4eed5a72ff9187cde702cd1c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01ca79005067e20e4eed5a72ff9187cde702cd1c.hip new file mode 100644 index 000000000000..9cc45b050f30 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01ca79005067e20e4eed5a72ff9187cde702cd1c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01cb354dddef6e99e4ac843f2adafcddfc58d520.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01cb354dddef6e99e4ac843f2adafcddfc58d520.hip new file mode 100644 index 000000000000..304214921a42 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01cb354dddef6e99e4ac843f2adafcddfc58d520.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 128, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128, + ck_tile::bf16_t, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01d12033d59ce2799a2a024e5d9232325ccf1320.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01d12033d59ce2799a2a024e5d9232325ccf1320.hip new file mode 100644 index 000000000000..12d5d208b29e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01d12033d59ce2799a2a024e5d9232325ccf1320.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01d3b034a2d8d0b83c0aefa4faac6c3f28ce737f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01d3b034a2d8d0b83c0aefa4faac6c3f28ce737f.hip new file mode 100644 index 000000000000..24c0a414244a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01d3b034a2d8d0b83c0aefa4faac6c3f28ce737f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01e2428c5447aa9a78f79f73f31cf685c586872d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01e2428c5447aa9a78f79f73f31cf685c586872d.hip new file mode 100644 index 000000000000..34117453c799 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01e2428c5447aa9a78f79f73f31cf685c586872d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01e8aedb7b7d77f44a46b2e9b7a826f245aaf4a7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01e8aedb7b7d77f44a46b2e9b7a826f245aaf4a7.hip new file mode 100644 index 000000000000..7dff16907280 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01e8aedb7b7d77f44a46b2e9b7a826f245aaf4a7.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01e8f0df0c54ce619e5b66441b3c96a5e18b05d6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01e8f0df0c54ce619e5b66441b3c96a5e18b05d6.hip new file mode 100644 index 000000000000..247ecb28f7a7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01e8f0df0c54ce619e5b66441b3c96a5e18b05d6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01ee0083f6df962c4a754cd3295b1a436c590a0e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01ee0083f6df962c4a754cd3295b1a436c590a0e.hip new file mode 100644 index 000000000000..c83349a6662a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01ee0083f6df962c4a754cd3295b1a436c590a0e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01f74764c3c3284fdd1b67d0ea781c2261ed0de6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01f74764c3c3284fdd1b67d0ea781c2261ed0de6.hip new file mode 100644 index 000000000000..ecf4290b6700 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_01f74764c3c3284fdd1b67d0ea781c2261ed0de6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; 
+ +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0225857454eaab2eb664aef7a0849ce12c32fdf9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0225857454eaab2eb664aef7a0849ce12c32fdf9.hip new file mode 100644 index 000000000000..2dcf470cb492 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0225857454eaab2eb664aef7a0849ce12c32fdf9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0237c76137df14fb808ade8bd6837045f2aaa5c9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0237c76137df14fb808ade8bd6837045f2aaa5c9.hip new file mode 100644 index 000000000000..780a28248b4f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0237c76137df14fb808ade8bd6837045f2aaa5c9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0271bd8b7c270e1593871b638288a4923342c446.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0271bd8b7c270e1593871b638288a4923342c446.hip new file mode 100644 index 000000000000..7df54cbcb255 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0271bd8b7c270e1593871b638288a4923342c446.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_02d88a03cd3966dd0cff550065f58c3ffecfff6c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_02d88a03cd3966dd0cff550065f58c3ffecfff6c.hip new file mode 100644 index 000000000000..898d17870ed7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_02d88a03cd3966dd0cff550065f58c3ffecfff6c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
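Every backward-pass file in this directory follows the same template: it pins one kernel configuration at compile time (block tile and warp tiles, fp16/bf16 data type, mask, bias, dropout, and padding/store flags), assembles the matching ck_tile pipeline, epilogues, and FmhaBwdDQDKDVKernel, and then emits explicit specializations of fmha_bwd_dq_dk_dv_, fmha_bwd_dq_dk_dv_oneshot_, and fmha_bwd_dq_dk_dv_get_name_ keyed by a dq_dk_dv_trait_ type, so the runtime dispatcher can route a given fmha_bwd_args to exactly one generated translation unit (the forward-pass files do the same with fmha_fwd_ and a fwd traits type). The sketch below illustrates that traits-keyed explicit-specialization dispatch in isolation; it is a minimal stand-alone example with hypothetical names (kernel_traits, run_kernel_, dispatch), not the actual CK or PyTorch identifiers.

// Hypothetical, simplified sketch of the dispatch pattern used by these
// autogenerated translation units. All names here are illustrative.
#include <cstdio>

// A shared header would declare one primary template per entry point
// (no definition in the common header)...
template <typename Traits>
float run_kernel_(float scale);

// ...and a small tag type encodes the compile-time configuration of one kernel.
template <int HDim, bool kIsCausal>
struct kernel_traits {};

// Each generated file then contributes exactly one explicit specialization.
template <>
float run_kernel_<kernel_traits<128, /*kIsCausal=*/false>>(float scale)
{
    std::printf("hdim=128, non-causal kernel, scale=%f\n", scale);
    return 0.5f; // stand-in for the measured kernel time
}

template <>
float run_kernel_<kernel_traits<64, /*kIsCausal=*/true>>(float scale)
{
    std::printf("hdim=64, causal kernel, scale=%f\n", scale);
    return 0.3f;
}

// A runtime dispatcher maps argument values onto the matching trait type.
float dispatch(int hdim, bool causal, float scale)
{
    if (hdim == 128 && !causal) return run_kernel_<kernel_traits<128, false>>(scale);
    if (hdim == 64 && causal)   return run_kernel_<kernel_traits<64, true>>(scale);
    return -1.0f; // unsupported configuration
}

int main()
{
    dispatch(128, false, 0.125f);
    dispatch(64, true, 0.125f);
}

A likely reason the generator emits one specialization per file is that each tile/bias/dropout combination can then be compiled independently, keeping individual translation units small and build parallelism high.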
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_02ff94e3c787a7b06ffc90c25777fa74f225e32c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_02ff94e3c787a7b06ffc90c25777fa74f225e32c.hip new file mode 100644 index 000000000000..7e08824282cb --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_02ff94e3c787a7b06ffc90c25777fa74f225e32c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_030a759dcc92028b4c6f317fc230b98cb929e806.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_030a759dcc92028b4c6f317fc230b98cb929e806.hip new file mode 100644 index 000000000000..cc44db887586 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_030a759dcc92028b4c6f317fc230b98cb929e806.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_031b12f9fd94e01aaff2c0da4f35f346822087e4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_031b12f9fd94e01aaff2c0da4f35f346822087e4.hip new file mode 100644 index 000000000000..2f4c053bc747 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_031b12f9fd94e01aaff2c0da4f35f346822087e4.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_036887daf6cc092e7422a17882488e59cecfb643.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_036887daf6cc092e7422a17882488e59cecfb643.hip new file mode 100644 index 000000000000..cd13097b8181 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_036887daf6cc092e7422a17882488e59cecfb643.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_037c6c80fcec3eb8b0bef50ad6af6d27bf5447f5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_037c6c80fcec3eb8b0bef50ad6af6d27bf5447f5.hip new file mode 100644 index 000000000000..ad0c065328ab --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_037c6c80fcec3eb8b0bef50ad6af6d27bf5447f5.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0392491c5a6dfc742c2be483419a40f6a7a7ea56.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0392491c5a6dfc742c2be483419a40f6a7a7ea56.hip new file mode 100644 index 000000000000..913e1819e46d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0392491c5a6dfc742c2be483419a40f6a7a7ea56.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_03a71615a088e972c998f9c7cb44566c268c5124.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_03a71615a088e972c998f9c7cb44566c268c5124.hip new file mode 100644 index 000000000000..d48e083be95e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_03a71615a088e972c998f9c7cb44566c268c5124.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_03ff035717140f7385282419598cb4fb2881ce8e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_03ff035717140f7385282419598cb4fb2881ce8e.hip new file mode 100644 index 000000000000..12bfd890c6ff --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_03ff035717140f7385282419598cb4fb2881ce8e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_041a0718891596ddac1fb0088637029233ccbe60.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_041a0718891596ddac1fb0088637029233ccbe60.hip new file mode 100644 index 000000000000..ed3cc93a059e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_041a0718891596ddac1fb0088637029233ccbe60.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_042a156e9eb935555ab14a84461959b466c2fb5b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_042a156e9eb935555ab14a84461959b466c2fb5b.hip new file mode 100644 index 000000000000..8a0da15aa5ee --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_042a156e9eb935555ab14a84461959b466c2fb5b.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_04641230fe9a50a221047f7a1df8a370f72805b9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_04641230fe9a50a221047f7a1df8a370f72805b9.hip new file mode 100644 index 000000000000..4ed1496285db --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_04641230fe9a50a221047f7a1df8a370f72805b9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, 
kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_04c363e11d202c6d2f4bb753661c5a2043edc0ad.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_04c363e11d202c6d2f4bb753661c5a2043edc0ad.hip new file mode 100644 index 000000000000..74be5c1f1952 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_04c363e11d202c6d2f4bb753661c5a2043edc0ad.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_04caeecbc01667ec6f5599358a0a20423aa9a00b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_04caeecbc01667ec6f5599358a0a20423aa9a00b.hip new file mode 100644 index 000000000000..81b58c933d5e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_04caeecbc01667ec6f5599358a0a20423aa9a00b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = 
ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_04f39b453505f68a5091f68b1c3de48369d1e7ea.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_04f39b453505f68a5091f68b1c3de48369d1e7ea.hip new file mode 100644 index 000000000000..513ed9b0ba6e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_04f39b453505f68a5091f68b1c3de48369d1e7ea.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_04ffca078cfab8bc6c4ccd1cc8994a1bb4a88ea7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_04ffca078cfab8bc6c4ccd1cc8994a1bb4a88ea7.hip new file mode 100644 index 000000000000..7498169d408d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_04ffca078cfab8bc6c4ccd1cc8994a1bb4a88ea7.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0502e718337eab7d47aa65cea7d3c5f641484520.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0502e718337eab7d47aa65cea7d3c5f641484520.hip new file mode 100644 index 000000000000..4248c1f8bc47 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0502e718337eab7d47aa65cea7d3c5f641484520.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0513b2f3bd8ad51315aadb7f63737201898adca8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0513b2f3bd8ad51315aadb7f63737201898adca8.hip new file mode 100644 index 000000000000..d2f0ff760beb --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0513b2f3bd8ad51315aadb7f63737201898adca8.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_053981d9e7af2ebc0f91e61ac5e25cbe68c95bd8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_053981d9e7af2ebc0f91e61ac5e25cbe68c95bd8.hip new file mode 100644 index 000000000000..e9188fb65ab1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_053981d9e7af2ebc0f91e61ac5e25cbe68c95bd8.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_054fda16133a0d25077967b05425f9128e1fe1a5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_054fda16133a0d25077967b05425f9128e1fe1a5.hip new file mode 100644 index 000000000000..515bcf6b65e6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_054fda16133a0d25077967b05425f9128e1fe1a5.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_05538339c21c92c53d237865d72debaaf2ee5075.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_05538339c21c92c53d237865d72debaaf2ee5075.hip new file mode 100644 index 000000000000..dd0fb7aea274 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_05538339c21c92c53d237865d72debaaf2ee5075.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0595316f0dfffda03e5296b959a49ec3f3c48d67.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0595316f0dfffda03e5296b959a49ec3f3c48d67.hip new file mode 100644 index 000000000000..8a8d141623d9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0595316f0dfffda03e5296b959a49ec3f3c48d67.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = 
ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_05dfe927fd64a564c5fad537fb7c41ee9c94c2c0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_05dfe927fd64a564c5fad537fb7c41ee9c94c2c0.hip new file mode 100644 index 000000000000..af08208640a9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_05dfe927fd64a564c5fad537fb7c41ee9c94c2c0.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_05e60b3ab7477f9edc8576a8bf43e3a62b8d5ef8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_05e60b3ab7477f9edc8576a8bf43e3a62b8d5ef8.hip new file mode 100644 index 000000000000..6173f431eac6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_05e60b3ab7477f9edc8576a8bf43e3a62b8d5ef8.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_05f794c7023cbb7e35f1fd1ae45bd2377bfbc520.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_05f794c7023cbb7e35f1fd1ae45bd2377bfbc520.hip new file mode 100644 index 000000000000..a461a8ad3c6c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_05f794c7023cbb7e35f1fd1ae45bd2377bfbc520.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0628931bf5cc1daa6e106cf60bb21fa1aac6b1df.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0628931bf5cc1daa6e106cf60bb21fa1aac6b1df.hip new file mode 100644 index 000000000000..36a2eba38f7d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0628931bf5cc1daa6e106cf60bb21fa1aac6b1df.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_062c8c3c1cf6c33af4574099e9b6ac54a55ad776.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_062c8c3c1cf6c33af4574099e9b6ac54a55ad776.hip new file mode 100644 index 000000000000..9fd67589da95 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_062c8c3c1cf6c33af4574099e9b6ac54a55ad776.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0682150e93f547e00f13cd8984779bf49b91e50c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0682150e93f547e00f13cd8984779bf49b91e50c.hip new file mode 100644 index 000000000000..2200cc8523f4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0682150e93f547e00f13cd8984779bf49b91e50c.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else 
+ return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_069c663be0267c009be4814e9e4e7c13ec999411.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_069c663be0267c009be4814e9e4e7c13ec999411.hip new file mode 100644 index 000000000000..e443f8ec7296 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_069c663be0267c009be4814e9e4e7c13ec999411.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_06ae52ef937cc27c544e32025ea0dadb7fad982d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_06ae52ef937cc27c544e32025ea0dadb7fad982d.hip new file mode 100644 index 000000000000..6100c507d6c3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_06ae52ef937cc27c544e32025ea0dadb7fad982d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_06b74acd9abfbd1c4ec2f4c718eeb92a0bca7bab.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_06b74acd9abfbd1c4ec2f4c718eeb92a0bca7bab.hip new file mode 100644 index 000000000000..e59ac40198df --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_06b74acd9abfbd1c4ec2f4c718eeb92a0bca7bab.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_06ba94794a14f0f0022af6f5f3c16e1e16959d4c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_06ba94794a14f0f0022af6f5f3c16e1e16959d4c.hip new file mode 100644 index 000000000000..9487fc53035d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_06ba94794a14f0f0022af6f5f3c16e1e16959d4c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_071751b1012b90f7b57f8591cd06ae1fd27d9cd3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_071751b1012b90f7b57f8591cd06ae1fd27d9cd3.hip new file mode 100644 index 000000000000..d437c6cd60ee --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_071751b1012b90f7b57f8591cd06ae1fd27d9cd3.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, 
kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0766e7aa4b263a811408b285213e47176ee2bdaf.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0766e7aa4b263a811408b285213e47176ee2bdaf.hip new file mode 100644 index 000000000000..f952c1535c96 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0766e7aa4b263a811408b285213e47176ee2bdaf.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_076b3beb57b30afb30636f948e3989b346b38d20.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_076b3beb57b30afb30636f948e3989b346b38d20.hip new file mode 100644 index 000000000000..0babc637454c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_076b3beb57b30afb30636f948e3989b346b38d20.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0789852b0cd3cc030c78b28f2fd5b6b0546382a4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0789852b0cd3cc030c78b28f2fd5b6b0546382a4.hip new file mode 100644 index 000000000000..866ebbec981b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0789852b0cd3cc030c78b28f2fd5b6b0546382a4.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_078b96ad691a85eebd18586db0b62b8911016d9c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_078b96ad691a85eebd18586db0b62b8911016d9c.hip new file mode 100644 index 000000000000..64c878eccfe8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_078b96ad691a85eebd18586db0b62b8911016d9c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE 
IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 
0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_07c3fc96d2bebe546dce6ebf46e5c7a519959599.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_07c3fc96d2bebe546dce6ebf46e5c7a519959599.hip new file mode 100644 index 000000000000..883582e9c67e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_07c3fc96d2bebe546dce6ebf46e5c7a519959599.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_07ff04fcc273e469737512893ea3fb5876ac131d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_07ff04fcc273e469737512893ea3fb5876ac131d.hip new file mode 100644 index 000000000000..2819cd974115 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_07ff04fcc273e469737512893ea3fb5876ac131d.hip @@ -0,0 +1,144 @@ +// ========================================== 
+// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0801c56831b4c6428200db6318638a2129bb197a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0801c56831b4c6428200db6318638a2129bb197a.hip new file mode 100644 index 000000000000..1d293ba529bb --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0801c56831b4c6428200db6318638a2129bb197a.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 128, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128, + ck_tile::fp16_t, + false, + true, + false, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0836d5dfc0f939ab9a4064b403339373caf35b56.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0836d5dfc0f939ab9a4064b403339373caf35b56.hip new file mode 100644 index 000000000000..66ca2006e4a7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0836d5dfc0f939ab9a4064b403339373caf35b56.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + 
return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0842c4e3aabdf55405b3ce09ce1899245ddf11ad.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0842c4e3aabdf55405b3ce09ce1899245ddf11ad.hip new file mode 100644 index 000000000000..2a53b540abec --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0842c4e3aabdf55405b3ce09ce1899245ddf11ad.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_085722b43cde5f37242edb071f639da7c4a0bd48.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_085722b43cde5f37242edb071f639da7c4a0bd48.hip new file mode 100644 index 000000000000..f42aa3ef4909 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_085722b43cde5f37242edb071f639da7c4a0bd48.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0878b9aa31429d23a93cd953cc6a2fc5f43d0d3a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0878b9aa31429d23a93cd953cc6a2fc5f43d0d3a.hip new file mode 100644 index 000000000000..256a6393e9da --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0878b9aa31429d23a93cd953cc6a2fc5f43d0d3a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_089a347aef8a920e3b59d5ffe71fc5bfe002609c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_089a347aef8a920e3b59d5ffe71fc5bfe002609c.hip new file mode 100644 index 000000000000..9463f524b13e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_089a347aef8a920e3b59d5ffe71fc5bfe002609c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_089de13222caec1483207d4a54249f8da4f9c151.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_089de13222caec1483207d4a54249f8da4f9c151.hip new file mode 100644 index 000000000000..e21b2f479e50 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_089de13222caec1483207d4a54249f8da4f9c151.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_091cb49c1958fb4342d79f367ea93cf2b472f785.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_091cb49c1958fb4342d79f367ea93cf2b472f785.hip new file mode 100644 index 000000000000..e6a493bdbd0e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_091cb49c1958fb4342d79f367ea93cf2b472f785.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_093834d4d3fe76e1745e4482c6b51b550c6f3dfc.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_093834d4d3fe76e1745e4482c6b51b550c6f3dfc.hip new file mode 100644 index 000000000000..0f9deffe70e1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_093834d4d3fe76e1745e4482c6b51b550c6f3dfc.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_09513bff5c1da6aadf11d2e8272a422eabff21bc.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_09513bff5c1da6aadf11d2e8272a422eabff21bc.hip new file mode 100644 index 000000000000..656d540789c7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_09513bff5c1da6aadf11d2e8272a422eabff21bc.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_096863cd93d1b105a617d0daa1d4f37d7fb6b893.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_096863cd93d1b105a617d0daa1d4f37d7fb6b893.hip new file mode 100644 index 000000000000..69bc9277c75b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_096863cd93d1b105a617d0daa1d4f37d7fb6b893.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0968cebd81ade762c2f92fffc0153fa7a2b91eb5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0968cebd81ade762c2f92fffc0153fa7a2b91eb5.hip new file mode 100644 index 000000000000..474717b1a205 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0968cebd81ade762c2f92fffc0153fa7a2b91eb5.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_096e888c52d0f4a5847d7515fcc66208b1ff40d3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_096e888c52d0f4a5847d7515fcc66208b1ff40d3.hip new file mode 100644 index 000000000000..74f365d61861 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_096e888c52d0f4a5847d7515fcc66208b1ff40d3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_097b3e1dae9bfb2e89398706508f8e01966fd4ea.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_097b3e1dae9bfb2e89398706508f8e01966fd4ea.hip new file mode 100644 index 000000000000..fa7da3a38549 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_097b3e1dae9bfb2e89398706508f8e01966fd4ea.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_09d76cca48b71dbcc9bd96734787209fee4c9a74.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_09d76cca48b71dbcc9bd96734787209fee4c9a74.hip new file mode 100644 index 000000000000..d51aeab7bf12 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_09d76cca48b71dbcc9bd96734787209fee4c9a74.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_09e50367b62bb09071e28b44235a7c112645a706.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_09e50367b62bb09071e28b44235a7c112645a706.hip new file mode 100644 index 000000000000..94cc58dc3149 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_09e50367b62bb09071e28b44235a7c112645a706.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_09ecb6347009f6a5d5530a6acf90f9f40288cbcf.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_09ecb6347009f6a5d5530a6acf90f9f40288cbcf.hip new file mode 100644 index 000000000000..544b7c9a8b46 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_09ecb6347009f6a5d5530a6acf90f9f40288cbcf.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0a2b116fd5065109aae46ee547e4f49ad0e9d6e1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0a2b116fd5065109aae46ee547e4f49ad0e9d6e1.hip new file mode 100644 index 000000000000..abd896784782 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0a2b116fd5065109aae46ee547e4f49ad0e9d6e1.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0a4e76d89b175e1d9fd2e9fb908d5fce1ebb945d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0a4e76d89b175e1d9fd2e9fb908d5fce1ebb945d.hip new file mode 100644 index 000000000000..e6671df3102f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0a4e76d89b175e1d9fd2e9fb908d5fce1ebb945d.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0a55ed15ef58c941e06dda890aeb530e28eb7bba.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0a55ed15ef58c941e06dda890aeb530e28eb7bba.hip new file mode 100644 index 000000000000..29800976b6ff --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0a55ed15ef58c941e06dda890aeb530e28eb7bba.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0a672fca51de618e3441cf8764e8e83eb782f2c7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0a672fca51de618e3441cf8764e8e83eb782f2c7.hip new file mode 100644 index 000000000000..e48d26d4247e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0a672fca51de618e3441cf8764e8e83eb782f2c7.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0a68c2f9a3acdd787b81be455cbc7836c8bfd90c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0a68c2f9a3acdd787b81be455cbc7836c8bfd90c.hip new file mode 100644 index 000000000000..27872f6fae40 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0a68c2f9a3acdd787b81be455cbc7836c8bfd90c.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, 
kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0a89417a043556970f72eebd48b4f3e7ac15377a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0a89417a043556970f72eebd48b4f3e7ac15377a.hip new file mode 100644 index 000000000000..eee7d6d67f14 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0a89417a043556970f72eebd48b4f3e7ac15377a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if 
(defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0a92671b6ea99891c0d69b1c793f4d131b9a82ed.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0a92671b6ea99891c0d69b1c793f4d131b9a82ed.hip new file mode 100644 index 000000000000..da711799b094 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0a92671b6ea99891c0d69b1c793f4d131b9a82ed.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using 
fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0aafb881e34a3794970a1282af740b3f19c138b1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0aafb881e34a3794970a1282af740b3f19c138b1.hip new file mode 100644 index 000000000000..673bab07358f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0aafb881e34a3794970a1282af740b3f19c138b1.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0ace6e29e1d3060c3086c08fe27b471e375f9c75.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0ace6e29e1d3060c3086c08fe27b471e375f9c75.hip new file mode 100644 index 000000000000..322083d0ac3a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0ace6e29e1d3060c3086c08fe27b471e375f9c75.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0ad9d68fcee021437e13ffdf94d78252205f5a31.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0ad9d68fcee021437e13ffdf94d78252205f5a31.hip new file mode 100644 index 000000000000..647879bfac53 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0ad9d68fcee021437e13ffdf94d78252205f5a31.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0b2647b5982405a48e8c8888552a4b89386ccdd9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0b2647b5982405a48e8c8888552a4b89386ccdd9.hip new file mode 100644 index 000000000000..f63d42b45ce5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0b2647b5982405a48e8c8888552a4b89386ccdd9.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0b2efefea81036641561bed80c75d77651176f74.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0b2efefea81036641561bed80c75d77651176f74.hip new file mode 100644 index 000000000000..ff32eecf1e7b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0b2efefea81036641561bed80c75d77651176f74.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0b3153af7bcdba33115a0d31f121fd76be2ffbcc.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0b3153af7bcdba33115a0d31f121fd76be2ffbcc.hip new file mode 100644 index 000000000000..85408d275fe7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0b3153af7bcdba33115a0d31f121fd76be2ffbcc.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0b532fcf26f90c82a792cde7943634f667c1d033.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0b532fcf26f90c82a792cde7943634f667c1d033.hip new file mode 100644 index 000000000000..21b0570baada --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0b532fcf26f90c82a792cde7943634f667c1d033.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0b90a0186d8b8004e3f19886c7992c8e04d0e066.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0b90a0186d8b8004e3f19886c7992c8e04d0e066.hip new file mode 100644 index 000000000000..2b98ae636738 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0b90a0186d8b8004e3f19886c7992c8e04d0e066.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0b9585ba1c10acf67115c5899b3546608541820d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0b9585ba1c10acf67115c5899b3546608541820d.hip new file mode 100644 index 000000000000..7520a8551e45 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0b9585ba1c10acf67115c5899b3546608541820d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0bb81407c8a2b3cdc5fecf655b3ad64d5d729cc9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0bb81407c8a2b3cdc5fecf655b3ad64d5d729cc9.hip new file mode 100644 index 000000000000..324bb25c1207 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0bb81407c8a2b3cdc5fecf655b3ad64d5d729cc9.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0bc7910aac798f0555e9e505ad7f177c9fbbd92c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0bc7910aac798f0555e9e505ad7f177c9fbbd92c.hip new file mode 100644 index 000000000000..fde747894ac6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0bc7910aac798f0555e9e505ad7f177c9fbbd92c.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 64, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0be8cf70c6be969ecfca675782c860b5b75ac089.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0be8cf70c6be969ecfca675782c860b5b75ac089.hip new file mode 100644 index 000000000000..e3bcce7ab85e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0be8cf70c6be969ecfca675782c860b5b75ac089.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0befed50a89d80c22b2c8c3d5ba67d73c3d0190e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0befed50a89d80c22b2c8c3d5ba67d73c3d0190e.hip new file mode 100644 index 000000000000..0a92513862c6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0befed50a89d80c22b2c8c3d5ba67d73c3d0190e.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0c32a2d9701e23dd930119c4ee8089042b5b0ac5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0c32a2d9701e23dd930119c4ee8089042b5b0ac5.hip new file mode 100644 index 000000000000..19dc9a2ff499 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0c32a2d9701e23dd930119c4ee8089042b5b0ac5.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0c3b2ec99fa7b09c7f78dcc3142a661d686044ac.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0c3b2ec99fa7b09c7f78dcc3142a661d686044ac.hip new file mode 100644 index 000000000000..e7b6c90da08f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0c3b2ec99fa7b09c7f78dcc3142a661d686044ac.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0c8a0bb89a6f05289c0405df5126fa0cc16252e7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0c8a0bb89a6f05289c0405df5126fa0cc16252e7.hip new file mode 100644 index 000000000000..ea61a8c916ec --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0c8a0bb89a6f05289c0405df5126fa0cc16252e7.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = 
ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0c93c65e5942a2f43f2e491547add02777dd2eee.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0c93c65e5942a2f43f2e491547add02777dd2eee.hip new file mode 100644 index 000000000000..2d8d09cf2eb6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0c93c65e5942a2f43f2e491547add02777dd2eee.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 32, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<32, + ck_tile::bf16_t, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0c9bd38b8f9009d932ec49204fdea39a52885246.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0c9bd38b8f9009d932ec49204fdea39a52885246.hip new file mode 100644 index 000000000000..8473617ba95d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0c9bd38b8f9009d932ec49204fdea39a52885246.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0caeedaa7d50f1741d618fb6c573529eebb075b1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0caeedaa7d50f1741d618fb6c573529eebb075b1.hip new file mode 100644 index 000000000000..78b0ce5c6f20 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0caeedaa7d50f1741d618fb6c573529eebb075b1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0cdef49859c80c6b3ba18eb2fb4c35c72abc1cf2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0cdef49859c80c6b3ba18eb2fb4c35c72abc1cf2.hip new file mode 100644 index 000000000000..1ce29ae2f992 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0cdef49859c80c6b3ba18eb2fb4c35c72abc1cf2.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0cee6b9427c164d78994150305a47f73954a67c0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0cee6b9427c164d78994150305a47f73954a67c0.hip new file mode 100644 index 000000000000..c54a987cac53 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0cee6b9427c164d78994150305a47f73954a67c0.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
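Each instantiation above also carries its tile geometry as compile-time integer sequences (the block tile, block-warp partitions, and warp tiles). Below is a minimal sketch, under the assumption that the sequence entries are consumed positionally as tile extents; int_seq and TileShapeSketch are hypothetical stand-ins for ck_tile::sequence and the TileFmhaBwdShape wrapper, and the M0/N0/K0 labels are illustrative only, not the confirmed meaning of the real entries.

// Sketch only: carrying tile extents as a compile-time integer sequence.
#include <cstddef>
#include <iostream>

template <int... Vals>
struct int_seq
{
    // Positional access to the packed values at compile time.
    static constexpr int at(std::size_t i)
    {
        constexpr int vals[] = {Vals...};
        return vals[i];
    }
    static constexpr std::size_t size() { return sizeof...(Vals); }
};

// A shape wrapper that names the positions, in the spirit of TileFmhaBwdShape.
template <typename BlockTile>
struct TileShapeSketch
{
    static constexpr int kM0 = BlockTile::at(0);  // illustrative: first tile extent
    static constexpr int kN0 = BlockTile::at(1);  // illustrative: second tile extent
    static constexpr int kK0 = BlockTile::at(2);  // illustrative: third tile extent
};

int main()
{
    using block_tile = int_seq<32, 128, 32, 32, 32, 32, 64, 32, 32>;
    using shape      = TileShapeSketch<block_tile>;
    std::cout << "M0=" << shape::kM0
              << " N0=" << shape::kN0
              << " K0=" << shape::kK0
              << " (" << block_tile::size() << " entries total)\n";
}

Because the extents are template parameters rather than runtime values, the kernels built from them can size shared memory and register tiles statically, which is the reason a separate generated file exists for every tile configuration.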
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0d0e0147a92061d32608a34e7b47bd534eb787fa.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0d0e0147a92061d32608a34e7b47bd534eb787fa.hip new file mode 100644 index 000000000000..78f56b54f70e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0d0e0147a92061d32608a34e7b47bd534eb787fa.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0d13a4c8d169877da6408584dc1f20a6f7c5e3aa.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0d13a4c8d169877da6408584dc1f20a6f7c5e3aa.hip new file mode 100644 index 000000000000..78dfc5a9b1ac --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0d13a4c8d169877da6408584dc1f20a6f7c5e3aa.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; 
+ +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0dde401aa76cb5425563cbbdb0362748148da3ca.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0dde401aa76cb5425563cbbdb0362748148da3ca.hip new file mode 100644 index 000000000000..db5239ecd29f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0dde401aa76cb5425563cbbdb0362748148da3ca.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0e007c36231ccdae12f102eacca1f74b0711b9c6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0e007c36231ccdae12f102eacca1f74b0711b9c6.hip new file mode 100644 index 000000000000..79a2821f3f75 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0e007c36231ccdae12f102eacca1f74b0711b9c6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0e0a2370f2a320484d8f9f21e3197425c2dbe9ad.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0e0a2370f2a320484d8f9f21e3197425c2dbe9ad.hip new file mode 100644 index 000000000000..11364420bf65 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0e0a2370f2a320484d8f9f21e3197425c2dbe9ad.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 32, + true, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, true>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0e1dbc9c433ce8ec33ace9e62550261d613db582.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0e1dbc9c433ce8ec33ace9e62550261d613db582.hip new file mode 100644 index 000000000000..db47fbad4d8d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0e1dbc9c433ce8ec33ace9e62550261d613db582.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 64, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0e3f4cd28a4c06cc109f6a0798a77844bcc750b7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0e3f4cd28a4c06cc109f6a0798a77844bcc750b7.hip new file mode 100644 index 000000000000..c086f7b6e86d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0e3f4cd28a4c06cc109f6a0798a77844bcc750b7.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0e661b5f30566d1f159f060c264849c7ae4772f1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0e661b5f30566d1f159f060c264849c7ae4772f1.hip new file mode 100644 index 000000000000..644f756253b8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0e661b5f30566d1f159f060c264849c7ae4772f1.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0ebacd06455ab20eba78b389462946716b5819f6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0ebacd06455ab20eba78b389462946716b5819f6.hip new file mode 100644 index 000000000000..59ec37be84d0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0ebacd06455ab20eba78b389462946716b5819f6.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0ef309b923172f4c0fb38d9b9f5325b33b4877c2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0ef309b923172f4c0fb38d9b9f5325b33b4877c2.hip new file mode 100644 index 000000000000..8a1b12f59bdf --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0ef309b923172f4c0fb38d9b9f5325b33b4877c2.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0ef9b9413697d6f4573c6605bff6f58d027c5016.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0ef9b9413697d6f4573c6605bff6f58d027c5016.hip new file mode 100644 index 000000000000..8182111c4a43 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0ef9b9413697d6f4573c6605bff6f58d027c5016.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0efdaa9266a5a464009297dc59db92504f8bf1a3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0efdaa9266a5a464009297dc59db92504f8bf1a3.hip new file mode 100644 index 000000000000..e6cf970a99a0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0efdaa9266a5a464009297dc59db92504f8bf1a3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0f0c699d9c3b0ed62097e38ba05e40e815cf474e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0f0c699d9c3b0ed62097e38ba05e40e815cf474e.hip new file mode 100644 index 000000000000..d6cfa8a50928 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0f0c699d9c3b0ed62097e38ba05e40e815cf474e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0f588dcb2ef86677ebf84e406eb802e9921d1f1e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0f588dcb2ef86677ebf84e406eb802e9921d1f1e.hip new file mode 100644 index 000000000000..ccfa82d52c18 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0f588dcb2ef86677ebf84e406eb802e9921d1f1e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) 
+ return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0fbb0bef3b388867e75d7a8a187b8b4b650a42ae.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0fbb0bef3b388867e75d7a8a187b8b4b650a42ae.hip new file mode 100644 index 000000000000..05fde574742b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0fbb0bef3b388867e75d7a8a187b8b4b650a42ae.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0fbddf533661642d84bf5a16149692d5a892182a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0fbddf533661642d84bf5a16149692d5a892182a.hip new file mode 100644 index 000000000000..4b4b70a80f2f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0fbddf533661642d84bf5a16149692d5a892182a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0fcb7492feb79e27e0bda73e57ef7dab410e2bb6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0fcb7492feb79e27e0bda73e57ef7dab410e2bb6.hip new file mode 100644 index 000000000000..3ee6242ee005 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0fcb7492feb79e27e0bda73e57ef7dab410e2bb6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0fd4068ea93fcf4df463e3bf3a6898d23b65da7f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0fd4068ea93fcf4df463e3bf3a6898d23b65da7f.hip new file mode 100644 index 000000000000..3e609731a5de --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_0fd4068ea93fcf4df463e3bf3a6898d23b65da7f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; 
+ +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_103186dbad604763008e0204a1ea90baecef8877.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_103186dbad604763008e0204a1ea90baecef8877.hip new file mode 100644 index 000000000000..0380a534950b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_103186dbad604763008e0204a1ea90baecef8877.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1037f1bc50c4a65dac09ba56b701256b701c4322.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1037f1bc50c4a65dac09ba56b701256b701c4322.hip new file mode 100644 index 000000000000..24ece6c55a71 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1037f1bc50c4a65dac09ba56b701256b701c4322.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_10a055e5c3d6a953d470db5dc21449766248058a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_10a055e5c3d6a953d470db5dc21449766248058a.hip new file mode 100644 index 000000000000..1cd4a3a3f7f7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_10a055e5c3d6a953d470db5dc21449766248058a.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 64, + 256, + true, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<256, + ck_tile::bf16_t, + true, + true, + false, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_10c24f1f9009e46afa3a59193784cc2575f79056.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_10c24f1f9009e46afa3a59193784cc2575f79056.hip new file mode 100644 index 000000000000..1a9606077e1e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_10c24f1f9009e46afa3a59193784cc2575f79056.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_10ceed95b0a0a01f844678717c88e0426fb503fd.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_10ceed95b0a0a01f844678717c88e0426fb503fd.hip new file mode 100644 index 000000000000..ba338ae6a363 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_10ceed95b0a0a01f844678717c88e0426fb503fd.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1132b11429034d96d82c82dbfdb69e460ad8a564.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1132b11429034d96d82c82dbfdb69e460ad8a564.hip new file mode 100644 index 000000000000..10220b3ed4df --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1132b11429034d96d82c82dbfdb69e460ad8a564.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_11e7df31541c3aa919e9825ad7dc4432f9a03c0c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_11e7df31541c3aa919e9825ad7dc4432f9a03c0c.hip new file mode 100644 index 000000000000..6bb1fa391544 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_11e7df31541c3aa919e9825ad7dc4432f9a03c0c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_11ff174ff2175e9ec22ac3a0fa59dd7713b79643.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_11ff174ff2175e9ec22ac3a0fa59dd7713b79643.hip new file mode 100644 index 000000000000..876953de89ee --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_11ff174ff2175e9ec22ac3a0fa59dd7713b79643.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1211733062ed30b876f1d63bffa642d77e258dd6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1211733062ed30b876f1d63bffa642d77e258dd6.hip new file mode 100644 index 000000000000..5b0886c2e104 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1211733062ed30b876f1d63bffa642d77e258dd6.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else 
+ return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_12207f4b6e7fac27d6c16493a5373f448a2aaae8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_12207f4b6e7fac27d6c16493a5373f448a2aaae8.hip new file mode 100644 index 000000000000..91425485ae58 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_12207f4b6e7fac27d6c16493a5373f448a2aaae8.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1241814f76107d74ed069ecec99a248676487eee.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1241814f76107d74ed069ecec99a248676487eee.hip new file mode 100644 index 000000000000..7f0d93f40f72 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1241814f76107d74ed069ecec99a248676487eee.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_12d5c8a4988efe60ef7943ecd73e18a28a736583.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_12d5c8a4988efe60ef7943ecd73e18a28a736583.hip new file mode 100644 index 000000000000..0948511ba533 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_12d5c8a4988efe60ef7943ecd73e18a28a736583.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_12d60c8abecb3bc9b84b0ea7851628ab17d8b0b3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_12d60c8abecb3bc9b84b0ea7851628ab17d8b0b3.hip new file mode 100644 index 000000000000..996900f70099 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_12d60c8abecb3bc9b84b0ea7851628ab17d8b0b3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE 
IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_131691f01cc7f29affb88152dd48c7a484315dcd.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_131691f01cc7f29affb88152dd48c7a484315dcd.hip new file mode 100644 index 000000000000..9c4507d8322a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_131691f01cc7f29affb88152dd48c7a484315dcd.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto 
[kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_131c1fdc4206bb952b2fea675f24e3b09f605eef.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_131c1fdc4206bb952b2fea675f24e3b09f605eef.hip new file mode 100644 index 000000000000..8129f0f62d6e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_131c1fdc4206bb952b2fea675f24e3b09f605eef.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 64, + true, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<64, + ck_tile::bf16_t, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_133c51948cf8584900807998da14d788039f53b9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_133c51948cf8584900807998da14d788039f53b9.hip new file mode 100644 index 000000000000..07dfd3b3131a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_133c51948cf8584900807998da14d788039f53b9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_135ea67de101135ed5fe04f5cab1ec1d7b3714bb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_135ea67de101135ed5fe04f5cab1ec1d7b3714bb.hip new file mode 100644 index 000000000000..30a8f33ad895 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_135ea67de101135ed5fe04f5cab1ec1d7b3714bb.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, 
kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_137fa6780d9e6bde10aec10a875c039fdbbc652e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_137fa6780d9e6bde10aec10a875c039fdbbc652e.hip new file mode 100644 index 000000000000..1eb2afd6ab7c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_137fa6780d9e6bde10aec10a875c039fdbbc652e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if 
(defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1386cd75411e61a8dbbaf2b916e62f4f5f99104f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1386cd75411e61a8dbbaf2b916e62f4f5f99104f.hip new file mode 100644 index 000000000000..86e10c62ce9d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1386cd75411e61a8dbbaf2b916e62f4f5f99104f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using 
fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_13d5f2ec83b3331654e37ea0b44d88cd98abaa37.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_13d5f2ec83b3331654e37ea0b44d88cd98abaa37.hip new file mode 100644 index 000000000000..cb483891a41a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_13d5f2ec83b3331654e37ea0b44d88cd98abaa37.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_13f747525ad31e76c88774fb2208e470da9c2310.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_13f747525ad31e76c88774fb2208e470da9c2310.hip new file mode 100644 index 000000000000..cc7e1a08d5d9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_13f747525ad31e76c88774fb2208e470da9c2310.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_14221590b90c48d3cf259fb4e834ccfaf7f3209b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_14221590b90c48d3cf259fb4e834ccfaf7f3209b.hip new file mode 100644 index 000000000000..032c888c89e5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_14221590b90c48d3cf259fb4e834ccfaf7f3209b.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); 
+#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_144f19363ef26efd36f0436cfa9f84f181a8824c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_144f19363ef26efd36f0436cfa9f84f181a8824c.hip new file mode 100644 index 000000000000..31d8a42c2d17 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_144f19363ef26efd36f0436cfa9f84f181a8824c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_146eb8c40e3146e06936f3141b2c4d92a578ddec.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_146eb8c40e3146e06936f3141b2c4d92a578ddec.hip new file mode 100644 index 000000000000..47e94e66edd1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_146eb8c40e3146e06936f3141b2c4d92a578ddec.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, 
blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_14baaaf1e90a075ab802c6e7d97c4b1605c8bd72.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_14baaaf1e90a075ab802c6e7d97c4b1605c8bd72.hip new file mode 100644 index 000000000000..bfd55686121b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_14baaaf1e90a075ab802c6e7d97c4b1605c8bd72.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if 
(defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_14c4ebd1792c781d219bd21b691b575f64635730.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_14c4ebd1792c781d219bd21b691b575f64635730.hip new file mode 100644 index 000000000000..b8d8881a17db --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_14c4ebd1792c781d219bd21b691b575f64635730.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using 
fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_14d11aad7b666f500f68b264a2fcca6dfc5f1a05.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_14d11aad7b666f500f68b264a2fcca6dfc5f1a05.hip new file mode 100644 index 000000000000..7d5612962a1d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_14d11aad7b666f500f68b264a2fcca6dfc5f1a05.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_14d4630876785655bd4950566e81ae0b645c0d3c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_14d4630876785655bd4950566e81ae0b645c0d3c.hip new file mode 100644 index 000000000000..f8ac18f98929 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_14d4630876785655bd4950566e81ae0b645c0d3c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_14f77aeeafe4b28f314fde5ebccfd2a554872781.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_14f77aeeafe4b28f314fde5ebccfd2a554872781.hip new file mode 100644 index 000000000000..39a73952876d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_14f77aeeafe4b28f314fde5ebccfd2a554872781.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_14fea611f3c253aebf726af3e5fdb7e63e18e13a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_14fea611f3c253aebf726af3e5fdb7e63e18e13a.hip new file mode 100644 index 000000000000..1c9a01c07017 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_14fea611f3c253aebf726af3e5fdb7e63e18e13a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_151a4425b411596c46c7032f6b83d3152a0e0cd4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_151a4425b411596c46c7032f6b83d3152a0e0cd4.hip new file mode 100644 index 000000000000..61478645543d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_151a4425b411596c46c7032f6b83d3152a0e0cd4.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, 
ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_153e897098539c3466da9d7a37234daf16476277.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_153e897098539c3466da9d7a37234daf16476277.hip new file mode 100644 index 000000000000..5d0bf7e6ec09 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_153e897098539c3466da9d7a37234daf16476277.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = 
fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1552dc38d26f6badb7a9bcb5ce9124d54cc45ed3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1552dc38d26f6badb7a9bcb5ce9124d54cc45ed3.hip new file mode 100644 index 000000000000..d90d53b8a775 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1552dc38d26f6badb7a9bcb5ce9124d54cc45ed3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_155bafb551768855c8c01faa63e44764ebe6c110.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_155bafb551768855c8c01faa63e44764ebe6c110.hip new file mode 100644 index 000000000000..2f6d56c17afb --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_155bafb551768855c8c01faa63e44764ebe6c110.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_155c3549d067464d186a99b8205317cc000d4898.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_155c3549d067464d186a99b8205317cc000d4898.hip new file mode 100644 index 000000000000..eb7e1a339282 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_155c3549d067464d186a99b8205317cc000d4898.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1573e3d855d28c54af612ab950b081302891d56d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1573e3d855d28c54af612ab950b081302891d56d.hip new file mode 100644 index 000000000000..abd04eb76ca4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1573e3d855d28c54af612ab950b081302891d56d.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_157768cd725813f8111d265cfdfea7f42034e5e9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_157768cd725813f8111d265cfdfea7f42034e5e9.hip new file mode 100644 index 000000000000..4e2fc118bc27 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_157768cd725813f8111d265cfdfea7f42034e5e9.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_157b89d8d625b8244b5cceaa4d3e5fc5a09c8989.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_157b89d8d625b8244b5cceaa4d3e5fc5a09c8989.hip new file mode 100644 index 000000000000..0c4b0b4ad5b5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_157b89d8d625b8244b5cceaa4d3e5fc5a09c8989.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_158d5ce564c3ae1eefb54e3d41dde2604560ef4a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_158d5ce564c3ae1eefb54e3d41dde2604560ef4a.hip new file mode 100644 index 000000000000..793d43cad59d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_158d5ce564c3ae1eefb54e3d41dde2604560ef4a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_159ee1f1b44d1a8fbaead65d8449413bb616d15e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_159ee1f1b44d1a8fbaead65d8449413bb616d15e.hip new file mode 100644 index 000000000000..6006941b6997 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_159ee1f1b44d1a8fbaead65d8449413bb616d15e.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + 
return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_15b255dde1a9d915e582ee2a83de7d83190c6a24.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_15b255dde1a9d915e582ee2a83de7d83190c6a24.hip new file mode 100644 index 000000000000..dc4ea63a6b69 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_15b255dde1a9d915e582ee2a83de7d83190c6a24.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_15cf7068183421b141ed5d6e7fe902d06b6492a1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_15cf7068183421b141ed5d6e7fe902d06b6492a1.hip new file mode 100644 index 000000000000..260c3559a59a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_15cf7068183421b141ed5d6e7fe902d06b6492a1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_15dc02ea7e0908cf0bd48034f5a49debfaa36219.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_15dc02ea7e0908cf0bd48034f5a49debfaa36219.hip new file mode 100644 index 000000000000..46fb46ef11a7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_15dc02ea7e0908cf0bd48034f5a49debfaa36219.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_15e8e1ab8c63db96843054bb7a98d708ae6a9c44.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_15e8e1ab8c63db96843054bb7a98d708ae6a9c44.hip new file mode 100644 index 000000000000..91b4d887de6f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_15e8e1ab8c63db96843054bb7a98d708ae6a9c44.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = 
ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_15fe3e8f4add16a088fe44458353fa7c0c4f9658.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_15fe3e8f4add16a088fe44458353fa7c0c4f9658.hip new file mode 100644 index 000000000000..65fac8b48c5c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_15fe3e8f4add16a088fe44458353fa7c0c4f9658.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_16047b5544acef40e39932672cac6f562e200948.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_16047b5544acef40e39932672cac6f562e200948.hip new file mode 100644 index 000000000000..85f56bdca7a2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_16047b5544acef40e39932672cac6f562e200948.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1621507cf219fe608715d4e5bb6e5764022e2d61.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1621507cf219fe608715d4e5bb6e5764022e2d61.hip new file mode 100644 index 000000000000..a80c66dfaae6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1621507cf219fe608715d4e5bb6e5764022e2d61.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_162b0dfbe3f615b1d164290799b2457437a0044b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_162b0dfbe3f615b1d164290799b2457437a0044b.hip new file mode 100644 index 000000000000..8ed77fdae461 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_162b0dfbe3f615b1d164290799b2457437a0044b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_164a947a6c2ba83a5b1cb7074aee0bdac6c9c64e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_164a947a6c2ba83a5b1cb7074aee0bdac6c9c64e.hip new file mode 100644 index 000000000000..55dc832eb938 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_164a947a6c2ba83a5b1cb7074aee0bdac6c9c64e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( 
+ s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_165dfb45658df8f1ae8dc0738ac9614740f2576c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_165dfb45658df8f1ae8dc0738ac9614740f2576c.hip new file mode 100644 index 000000000000..68ef183d6990 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_165dfb45658df8f1ae8dc0738ac9614740f2576c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = 
fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_167f5328b035ed59a6f05dfee31edd704c4b07ee.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_167f5328b035ed59a6f05dfee31edd704c4b07ee.hip new file mode 100644 index 000000000000..d3adafe0b286 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_167f5328b035ed59a6f05dfee31edd704c4b07ee.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1687ddf65ce4ed2997583e20fee9f201e86633b3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1687ddf65ce4ed2997583e20fee9f201e86633b3.hip new file mode 100644 index 000000000000..4f5830b7d444 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1687ddf65ce4ed2997583e20fee9f201e86633b3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_16f94f5c65c37624f5458c165daf83517d9e3c81.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_16f94f5c65c37624f5458c165daf83517d9e3c81.hip new file mode 100644 index 000000000000..86c7f2cba0eb --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_16f94f5c65c37624f5458c165daf83517d9e3c81.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_173c44dd85077e6b12dd06fdcf6b11ba349e1866.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_173c44dd85077e6b12dd06fdcf6b11ba349e1866.hip new file mode 100644 index 000000000000..919a3ac828c6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_173c44dd85077e6b12dd06fdcf6b11ba349e1866.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_17b9b96edda151072215502cc2b606bf1f6f0b03.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_17b9b96edda151072215502cc2b606bf1f6f0b03.hip new file mode 100644 index 000000000000..d41b5dcef34f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_17b9b96edda151072215502cc2b606bf1f6f0b03.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1847fef2c06ea581b0ab31af1cb0556c572696ad.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1847fef2c06ea581b0ab31af1cb0556c572696ad.hip new file mode 100644 index 000000000000..d9bc327995bc --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1847fef2c06ea581b0ab31af1cb0556c572696ad.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_187963e1969301abfa61d06afc97faea2bb4efb1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_187963e1969301abfa61d06afc97faea2bb4efb1.hip new file mode 100644 index 000000000000..da9c2fe71e23 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_187963e1969301abfa61d06afc97faea2bb4efb1.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 128, + true, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, false>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1886d4bf54b3a4a9e093360998b2059b3c03d072.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1886d4bf54b3a4a9e093360998b2059b3c03d072.hip new file mode 
100644 index 000000000000..a0ac88f518c7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1886d4bf54b3a4a9e093360998b2059b3c03d072.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = 
fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_188a70d526394e254274df95de0727850820326c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_188a70d526394e254274df95de0727850820326c.hip new file mode 100644 index 000000000000..62ef602df6e7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_188a70d526394e254274df95de0727850820326c.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1899e28aff2fb168cdc3af7132dd7fd09c2e1ced.hip 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1899e28aff2fb168cdc3af7132dd7fd09c2e1ced.hip new file mode 100644 index 000000000000..bdc5cbba2aef --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1899e28aff2fb168cdc3af7132dd7fd09c2e1ced.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void 
fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_18a4d71b31c451a50df7996e3db864bc3c3882ed.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_18a4d71b31c451a50df7996e3db864bc3c3882ed.hip new file mode 100644 index 000000000000..997da7f6a12c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_18a4d71b31c451a50df7996e3db864bc3c3882ed.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + 
fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_18b92b4e249195ac3e0c74d246585a4c9e0992fd.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_18b92b4e249195ac3e0c74d246585a4c9e0992fd.hip new file mode 100644 index 000000000000..304a37b44890 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_18b92b4e249195ac3e0c74d246585a4c9e0992fd.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 64, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_18ed7195a9443c84956c3f32839cb3ab9056bdfc.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_18ed7195a9443c84956c3f32839cb3ab9056bdfc.hip new file mode 100644 index 000000000000..179dbf59798e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_18ed7195a9443c84956c3f32839cb3ab9056bdfc.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1914250fce818584291c69a5f058a58cfbd83df9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1914250fce818584291c69a5f058a58cfbd83df9.hip new file mode 100644 index 000000000000..ffab9800ad51 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1914250fce818584291c69a5f058a58cfbd83df9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_193699a5daa14ca2def07489e0b563149bc403f8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_193699a5daa14ca2def07489e0b563149bc403f8.hip new file mode 100644 index 000000000000..996f9773c4c9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_193699a5daa14ca2def07489e0b563149bc403f8.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_19af6a7f9e5020e8d0f0ca0f6258001f6ce592c1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_19af6a7f9e5020e8d0f0ca0f6258001f6ce592c1.hip new file mode 100644 index 000000000000..8eee4ffa35da --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_19af6a7f9e5020e8d0f0ca0f6258001f6ce592c1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_19cd9f7b08cec83736605af63d9fcaf463a1aea4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_19cd9f7b08cec83736605af63d9fcaf463a1aea4.hip new file mode 100644 index 000000000000..b353542ec8ce --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_19cd9f7b08cec83736605af63d9fcaf463a1aea4.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 128, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_19df4e13108e043361e9528b71df56f04f696a0c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_19df4e13108e043361e9528b71df56f04f696a0c.hip new file mode 100644 index 000000000000..69ecf815ab6e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_19df4e13108e043361e9528b71df56f04f696a0c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a11dd5ebb989503a1c182684e7f247e2f8cd9c2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a11dd5ebb989503a1c182684e7f247e2f8cd9c2.hip new file mode 100644 index 000000000000..d03aa0275bb2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a11dd5ebb989503a1c182684e7f247e2f8cd9c2.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 32, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a236be9da05a07d11cd28034d90cdf89941a172.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a236be9da05a07d11cd28034d90cdf89941a172.hip new file mode 100644 index 000000000000..765925ad32e9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a236be9da05a07d11cd28034d90cdf89941a172.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a5e18f6333ed2cce509f07cb8bd5868951d66a0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a5e18f6333ed2cce509f07cb8bd5868951d66a0.hip new file mode 100644 index 000000000000..d3ceedc3ba29 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a5e18f6333ed2cce509f07cb8bd5868951d66a0.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a6785392af35e27d6697b584cb6f17a766d3fee.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a6785392af35e27d6697b584cb6f17a766d3fee.hip new file mode 100644 index 000000000000..45412a0172dd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a6785392af35e27d6697b584cb6f17a766d3fee.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; 
+ +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a6bc2762b95d550485aa720edaf71138d94cd07.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a6bc2762b95d550485aa720edaf71138d94cd07.hip new file mode 100644 index 000000000000..9d1782fd4f6d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a6bc2762b95d550485aa720edaf71138d94cd07.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a8da3e6ab050262b659c801ccf9a14787d7f176.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a8da3e6ab050262b659c801ccf9a14787d7f176.hip new file mode 100644 index 000000000000..81ea5cb5728b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a8da3e6ab050262b659c801ccf9a14787d7f176.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a96f0ac76f117e66eba97cb990c2350561ec2ab.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a96f0ac76f117e66eba97cb990c2350561ec2ab.hip new file mode 100644 index 000000000000..07305c4de118 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a96f0ac76f117e66eba97cb990c2350561ec2ab.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a98bcbe900f8c141136d18c114b02fffbe8bca1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a98bcbe900f8c141136d18c114b02fffbe8bca1.hip new file mode 100644 index 000000000000..084f97abecca --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a98bcbe900f8c141136d18c114b02fffbe8bca1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a99b2625adffa8215276bb88fc65bae944b846b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a99b2625adffa8215276bb88fc65bae944b846b.hip new file mode 100644 index 000000000000..6c9b49e2ebd3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1a99b2625adffa8215276bb88fc65bae944b846b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1acf2f892742b1d236d2b31a8185c6869126adad.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1acf2f892742b1d236d2b31a8185c6869126adad.hip new file mode 100644 index 000000000000..08cb9227ae14 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1acf2f892742b1d236d2b31a8185c6869126adad.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, 
kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1b3e7c8969027d3316875f33dc50fe022e05ce37.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1b3e7c8969027d3316875f33dc50fe022e05ce37.hip new file mode 100644 index 000000000000..f32d6823f03e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1b3e7c8969027d3316875f33dc50fe022e05ce37.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 128, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128, + ck_tile::bf16_t, + false, + false, + true, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1be43f8b629e7039f57b95866d5777273377470d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1be43f8b629e7039f57b95866d5777273377470d.hip new file mode 100644 index 000000000000..de2671c488ad --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1be43f8b629e7039f57b95866d5777273377470d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
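The generated files above all follow the same shape: each translation unit defines a trait alias (e.g. dq_dk_dv_trait_0) plus an explicit specialization of a dispatch function template (fmha_bwd_dq_dk_dv_, fmha_fwd_, ...), and a runtime dispatcher later picks the matching specialization from the runtime arguments. The following is a minimal standalone sketch of that pattern in plain C++, assuming nothing beyond the standard library; toy_trait, run_specialized and dispatch are hypothetical stand-ins, not ck_tile APIs, and the sketch is not part of the patch itself.

#include <iostream>

// Primary template, normally declared once in a shared header
// (analogue of the fmha_bwd_dq_dk_dv_<trait> declaration these files specialize).
template <typename Trait>
float run_specialized(float scale);

// A tag type carrying compile-time configuration
// (analogue of dq_dk_dv_trait_0 in the generated files).
template <int HDim, bool IsCausal>
struct toy_trait {};

// One generated translation unit contributes exactly one explicit specialization.
template <>
float run_specialized<toy_trait<64, false>>(float scale)
{
    // The real files launch a ck_tile kernel here; the sketch just returns a
    // value so the dispatch path is observable.
    return 64.0f * scale;
}

// Runtime dispatcher: maps runtime arguments onto a compile-time trait
// (analogue of fmha_bwd/fmha_fwd selecting a kernel from the fmha_*_args struct).
float dispatch(int hdim, bool causal, float scale)
{
    if (hdim <= 64 && !causal)
        return run_specialized<toy_trait<64, false>>(scale);
    return -1.0f; // configuration not covered by this sketch
}

int main()
{
    std::cout << dispatch(64, /*causal=*/false, 2.0f) << std::endl; // prints 128
    return 0;
}

Splitting one specialization per file, as the patch does, keeps each .hip unit small enough to compile independently and lets the build parallelize across the many trait combinations.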
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1be746990a2032f0363ad9f9112cc994983f4706.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1be746990a2032f0363ad9f9112cc994983f4706.hip new file mode 100644 index 000000000000..7605150e8319 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1be746990a2032f0363ad9f9112cc994983f4706.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1bf767e7104cfc8322f26df35907fbf04b8948f3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1bf767e7104cfc8322f26df35907fbf04b8948f3.hip new file mode 100644 index 000000000000..f37ebaa9855e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1bf767e7104cfc8322f26df35907fbf04b8948f3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; 
+ +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1c1b0f85e085dd0769c566fb16aafe5ab5952714.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1c1b0f85e085dd0769c566fb16aafe5ab5952714.hip new file mode 100644 index 000000000000..5437e8aac265 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1c1b0f85e085dd0769c566fb16aafe5ab5952714.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1c2a2d78176e3f0a78e3ad78217e75a4430c0de5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1c2a2d78176e3f0a78e3ad78217e75a4430c0de5.hip new file mode 100644 index 000000000000..903d320799ea --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1c2a2d78176e3f0a78e3ad78217e75a4430c0de5.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1c65ba6dba01da9caa84ba89453b61d81376763f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1c65ba6dba01da9caa84ba89453b61d81376763f.hip new file mode 100644 index 000000000000..cc965b6f7079 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1c65ba6dba01da9caa84ba89453b61d81376763f.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, 
kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1ca3f45d0be2d1119cccd0af042a3e8adeda2ed7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1ca3f45d0be2d1119cccd0af042a3e8adeda2ed7.hip new file mode 100644 index 000000000000..30608a330e5a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1ca3f45d0be2d1119cccd0af042a3e8adeda2ed7.hip @@ -0,0 +1,1965 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s){ + float r = -1; + if(t.data_type.compare("fp16") == 0){ + if (t.hdim_q <= 32 && t.hdim_v <= 32) { + if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) 
&& (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, 
true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == 
false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, 
false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 
== 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == 
mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && 
(a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, 
true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::fp16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + + } + else if (t.hdim_q <= 64 && t.hdim_v <= 64) { + if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == 
true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, 
ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 
64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && 
(t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, 
true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == 
true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, 
ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && 
(t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, 
ck_tile::fp16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::fp16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + + } + else if (t.hdim_q <= 128 && t.hdim_v <= 128) { + if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + 
else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = 
fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == 
mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, 
ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && 
(t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, 
ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && 
(t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, 
true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } 
+ else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::fp16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + + } + else if (t.hdim_q <= 256 && t.hdim_v <= 256) { + if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + 
(true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, 
ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == 
bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, 
ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && 
(t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 
256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, 
ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else 
if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && 
(t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::fp16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + + } + + } + else if(t.data_type.compare("bf16") == 0){ + if (t.hdim_q <= 32 && t.hdim_v <= 32) { + if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, 
ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == 
bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, 
ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && 
a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else 
if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 
16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) 
&& (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad 
always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<32, ck_tile::bf16_t, true, 128, 64, 16, 32, 32, 32, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, 
ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + + } + else if (t.hdim_q <= 64 && t.hdim_v <= 64) { + if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant 
== false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, false, true, true>; + return 
fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, 
ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == 
bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, 
ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 64 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 64 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode 
spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, 
ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != 
mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<64, ck_tile::bf16_t, true, 128, 64, 32, 64, 32, 64, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + + } + else if (t.hdim_q <= 128 && t.hdim_v <= 128) { + if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 
8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && 
(t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 
128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && 
(t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, 
ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == 
false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k != 0 && a.seqlen_k % 128 == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, false, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true) && (a.seqlen_k == 0 || a.seqlen_k % 128 != 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, 
ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != 
mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && 
(a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + using trait_ = fmha_fwd_traits_<128, ck_tile::bf16_t, true, 128, 128, 32, 128, 32, 128, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + + } + else if (t.hdim_q <= 256 && t.hdim_v <= 256) { + if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, 
ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == 
bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 
128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && 
(t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 
== 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return 
fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (a.seqlen_q % 128 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false, false>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == false) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true /*a.seqlen_q % 128 != 0*/) && (true /*a.seqlen_k % 128 != 0*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == 
false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true 
/*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == true) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, true, 128, 128, 32, 256, 32, 256, true, 
ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == true) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + else if((t.is_group_mode == true) && (t.is_v_rowmajor == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_lse == false) && (t.has_dropout == false) && (t.do_fp8_static_quant == false) && + (true/*group mode spad always true*/) && (true/*group mode skpad always true*/) && (true /*a.hdim_q % 256 != 0*/) && (true /*a.hdim_v % 256 != 0*/)) { + using trait_ = fmha_fwd_traits_<256, ck_tile::bf16_t, true, 128, 128, 32, 256, 32, 256, true, ck_tile::BlockFmhaPipelineEnum::QRKSVS, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + return fmha_fwd_(s, a); + } + + } + + } + + return r; +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1cbf88db44aa5f884438288a325270d29c7a04b6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1cbf88db44aa5f884438288a325270d29c7a04b6.hip new file mode 100644 index 000000000000..f89fcd9026d6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1cbf88db44aa5f884438288a325270d29c7a04b6.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1cc459e57bfed5ec7f40ea4a4dd9f72f3ad7a709.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1cc459e57bfed5ec7f40ea4a4dd9f72f3ad7a709.hip new file mode 100644 index 000000000000..a7b6ac361d10 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1cc459e57bfed5ec7f40ea4a4dd9f72f3ad7a709.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
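Editorial note: every generated .hip translation unit in this hunk follows the same skeleton — tile/warp shape aliases, a traits pack, a pipeline problem, a pipeline, epilogues, one kernel type, and then explicit specializations of the launcher entry points guarded by the supported GPU architectures. The condensed sketch below is illustrative only; StreamConfig, Args, KernelVariant, launch_timed and launch_oneshot are hypothetical stand-ins, not the ck_tile API, and the real files bind the kernel type through the fmha_*_traits_ template parameters shown above.

// Illustrative sketch (stand-in names) of the per-file pattern: one compile-time
// variant type, a timed launcher specialization, a fire-and-forget one-shot
// specialization (backward files), and a name accessor, all guarded so the
// kernel is only launched when compiled for a supported GPU architecture.
#include <iostream>
#include <string>

struct StreamConfig { int log_level_ = 0; int stream_id_ = 0; };  // stand-in for ck_tile::stream_config
struct Args {};                                                   // stand-in for fmha_*_args

template <int kHeadDim>
struct KernelVariant {                                            // stand-in for one generated kernel type
    static std::string GetName() { return "fmha_hd" + std::to_string(kHeadDim); }
    static float Run(const StreamConfig&, const Args&) {
#if defined(__gfx90a__) || defined(__gfx942__)
        return 1.0f;   // a real file calls ck_tile::launch_kernel here and returns its float result
#else
        return 0.0f;   // unsupported architecture: the generated files simply return 0.0
#endif
    }
};

// Timed entry point: logs the kernel name when verbose, then launches.
template <typename Variant>
float launch_timed(const StreamConfig& s, const Args& a) {
    if (s.log_level_ > 0)
        std::cout << ", " << Variant::GetName() << std::flush;
    return Variant::Run(s, a);
}

// One-shot entry point: launches on the given stream without timing or logging.
template <typename Variant>
void launch_oneshot(const StreamConfig& s, const Args& a) {
    (void)Variant::Run(s, a);
}

int main() {
    StreamConfig s{1, 0};
    Args a{};
    launch_timed<KernelVariant<128>>(s, a);    // each generated file pins exactly one variant
    launch_oneshot<KernelVariant<128>>(s, a);
}

In the actual patch the same split appears as fmha_fwd_ / fmha_bwd_dq_dk_dv_ (timed, returns a float), fmha_bwd_dq_dk_dv_oneshot_ (void), and fmha_bwd_dq_dk_dv_get_name_, each explicitly specialized once per generated file.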
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1d02609fb803ea2697e2c2cef35e6f923d2578cf.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1d02609fb803ea2697e2c2cef35e6f923d2578cf.hip new file mode 100644 index 000000000000..26c89e012916 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1d02609fb803ea2697e2c2cef35e6f923d2578cf.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1d0b822743e0205f60521d38d7c64f589fdf0f58.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1d0b822743e0205f60521d38d7c64f589fdf0f58.hip new file mode 100644 index 000000000000..a8b77d2a4e2c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1d0b822743e0205f60521d38d7c64f589fdf0f58.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1d21263e16dafe79b9fe2f998847296e575c14e7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1d21263e16dafe79b9fe2f998847296e575c14e7.hip new file mode 100644 index 000000000000..5f35d549931e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1d21263e16dafe79b9fe2f998847296e575c14e7.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1d3ef3d5ded0dfe2a0bafb52ea8f841658db35fd.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1d3ef3d5ded0dfe2a0bafb52ea8f841658db35fd.hip new file mode 100644 index 000000000000..132239fd081b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1d3ef3d5ded0dfe2a0bafb52ea8f841658db35fd.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1d498e418ebbf33bed58b4074d1edf3d9bdd07c5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1d498e418ebbf33bed58b4074d1edf3d9bdd07c5.hip new file mode 100644 index 000000000000..fa10992280fc --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1d498e418ebbf33bed58b4074d1edf3d9bdd07c5.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1da23de9604b5d98fe02529075bad995954c12ca.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1da23de9604b5d98fe02529075bad995954c12ca.hip new file mode 100644 index 000000000000..9966b0db808a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1da23de9604b5d98fe02529075bad995954c12ca.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1db03461737f1e359f389a8d297476f9b60faabd.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1db03461737f1e359f389a8d297476f9b60faabd.hip new file mode 100644 index 000000000000..74dcf07cfb16 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1db03461737f1e359f389a8d297476f9b60faabd.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1dc6e599144a093203fd7f92ac6d3c2cd7180d49.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1dc6e599144a093203fd7f92ac6d3c2cd7180d49.hip new file mode 100644 index 000000000000..8840038e5784 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1dc6e599144a093203fd7f92ac6d3c2cd7180d49.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
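Editorial note: the backward files in this range differ only in the boolean/enum pack passed to fmha_bwd_dq_dk_dv_traits_ (padding flags, bias kind, dropout, deterministic mode, and so on). At runtime, a dispatcher like the else-if chain at the top of this hunk matches those flags and forwards to exactly one of these explicit specializations. A minimal sketch of that selection pattern follows; Flags and run_variant are hypothetical stand-in names, not the generated API, and only two of the many flags are shown.

// Minimal sketch of the runtime-flag -> compile-time-trait dispatch behind
// these generated variants. Names are illustrative stand-ins, not ck_tile API.
#include <iostream>

struct Flags {                 // stand-in for the runtime traits struct (fmha_*_traits)
    bool has_dropout;
    bool is_deterministic;
};

template <bool kHasDropout, bool kIsDeterministic>
float run_variant() {          // stand-in for fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(s, a)
    std::cout << "dropout=" << kHasDropout
              << " deterministic=" << kIsDeterministic << "\n";
    return 0.0f;
}

// One branch per flag combination, mirroring the generated else-if chains;
// each branch binds a different explicit instantiation, which lives in its
// own autogenerated .hip file.
float dispatch(const Flags& f) {
    if (f.has_dropout && f.is_deterministic)   return run_variant<true,  true>();
    if (f.has_dropout && !f.is_deterministic)  return run_variant<true,  false>();
    if (!f.has_dropout && f.is_deterministic)  return run_variant<false, true>();
    return run_variant<false, false>();
}

int main() { return static_cast<int>(dispatch({true, false})); }

This is why the patch contains so many near-identical files: the generator enumerates the supported flag combinations ahead of time so that each kernel variant compiles in its own translation unit.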
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1de2f97d49f015b9af0b186801e939c6f357a0c4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1de2f97d49f015b9af0b186801e939c6f357a0c4.hip new file mode 100644 index 000000000000..ff58791e85f8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1de2f97d49f015b9af0b186801e939c6f357a0c4.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1df893ee660d37fba7eaca452ae65b3e45a73087.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1df893ee660d37fba7eaca452ae65b3e45a73087.hip new file mode 100644 index 000000000000..696d403eda66 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1df893ee660d37fba7eaca452ae65b3e45a73087.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); 
+#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1e22f2d99804198c61251b4629a3f18ed3dcd42e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1e22f2d99804198c61251b4629a3f18ed3dcd42e.hip new file mode 100644 index 000000000000..8590d328eb8b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1e22f2d99804198c61251b4629a3f18ed3dcd42e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1e33ce1fa113b221e5303b4093c2c4e748ce8298.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1e33ce1fa113b221e5303b4093c2c4e748ce8298.hip new file mode 100644 index 000000000000..75d061fed3da --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1e33ce1fa113b221e5303b4093c2c4e748ce8298.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1e42736d4f677a59a172bd6f162616a437696351.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1e42736d4f677a59a172bd6f162616a437696351.hip new file mode 100644 index 000000000000..ef6b29cf938f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1e42736d4f677a59a172bd6f162616a437696351.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1e7d7888480b83c78833214b32e10f37a6e20301.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1e7d7888480b83c78833214b32e10f37a6e20301.hip new file mode 100644 index 000000000000..9869c96ce0eb --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1e7d7888480b83c78833214b32e10f37a6e20301.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 64, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<64, + ck_tile::bf16_t, + false, + true, + false, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1e9130607a2d24cb0662a47e9cf12c6602143838.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1e9130607a2d24cb0662a47e9cf12c6602143838.hip new file mode 100644 index 000000000000..6a5fef08dcb2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1e9130607a2d24cb0662a47e9cf12c6602143838.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1e943fcc2e64c618fc1415b3f1a0db4d70aa8494.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1e943fcc2e64c618fc1415b3f1a0db4d70aa8494.hip new file mode 100644 index 000000000000..9f0ab6075909 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1e943fcc2e64c618fc1415b3f1a0db4d70aa8494.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1edaf9d4270d2ac61c299320e06ba73f44730364.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1edaf9d4270d2ac61c299320e06ba73f44730364.hip new file mode 100644 index 000000000000..fdbfe7482c95 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1edaf9d4270d2ac61c299320e06ba73f44730364.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); 
+#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1f0cad6ad5b172e51c569e84cd54a19b4eb0ed05.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1f0cad6ad5b172e51c569e84cd54a19b4eb0ed05.hip new file mode 100644 index 000000000000..4c3c7d43b9e6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1f0cad6ad5b172e51c569e84cd54a19b4eb0ed05.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if 
(defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1f13a6d0f8c798c0c4ba4ad202d081899fe081ab.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1f13a6d0f8c798c0c4ba4ad202d081899fe081ab.hip new file mode 100644 index 000000000000..2982627aa80d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1f13a6d0f8c798c0c4ba4ad202d081899fe081ab.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using 
fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1f6bc5faf18be193212217788d476ce6fd384bfb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1f6bc5faf18be193212217788d476ce6fd384bfb.hip new file mode 100644 index 000000000000..25a407fd40f5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1f6bc5faf18be193212217788d476ce6fd384bfb.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1f7faa0b33a9aada86f032174afd40d18efa7715.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1f7faa0b33a9aada86f032174afd40d18efa7715.hip new file mode 100644 index 000000000000..298a4e9b153b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1f7faa0b33a9aada86f032174afd40d18efa7715.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) 
+ return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1f81f8cce0d77dec9f977b9eeb0778b70a13fa75.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1f81f8cce0d77dec9f977b9eeb0778b70a13fa75.hip new file mode 100644 index 000000000000..23f4ba8fab55 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1f81f8cce0d77dec9f977b9eeb0778b70a13fa75.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1fcdcb750f382fc7828a9886585f50efbe5be735.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1fcdcb750f382fc7828a9886585f50efbe5be735.hip new file mode 100644 index 000000000000..790586338a16 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1fcdcb750f382fc7828a9886585f50efbe5be735.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1fd9fa7c2e13d0bad5fddb2b5a316bbc09d397ea.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1fd9fa7c2e13d0bad5fddb2b5a316bbc09d397ea.hip new file mode 100644 index 000000000000..7ee363221c27 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1fd9fa7c2e13d0bad5fddb2b5a316bbc09d397ea.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1fda1c96568eab89a8f6498f8bb23c1223cdc7b0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1fda1c96568eab89a8f6498f8bb23c1223cdc7b0.hip new file mode 100644 index 000000000000..cdc06572111c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_1fda1c96568eab89a8f6498f8bb23c1223cdc7b0.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2005aca3520b171bb82d10ad70fef44f28c19776.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2005aca3520b171bb82d10ad70fef44f28c19776.hip new file mode 100644 index 000000000000..3cb13e210426 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2005aca3520b171bb82d10ad70fef44f28c19776.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_204a573ce6b7d2f90aede543939315561cc43177.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_204a573ce6b7d2f90aede543939315561cc43177.hip new file mode 100644 index 000000000000..038304e4237d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_204a573ce6b7d2f90aede543939315561cc43177.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_20588bcac681a5d69f252d7523a3681a0c6b6181.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_20588bcac681a5d69f252d7523a3681a0c6b6181.hip new file mode 100644 index 000000000000..5f9a6cd1607b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_20588bcac681a5d69f252d7523a3681a0c6b6181.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2081430c92864c29bb9f409e7c27caee1de00749.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2081430c92864c29bb9f409e7c27caee1de00749.hip new file mode 100644 index 000000000000..a66950b1b4b8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2081430c92864c29bb9f409e7c27caee1de00749.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_20d5c3c86398f6ce55abc90db3e362dbf9f457f2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_20d5c3c86398f6ce55abc90db3e362dbf9f457f2.hip new file mode 100644 index 000000000000..05e5cc3248c6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_20d5c3c86398f6ce55abc90db3e362dbf9f457f2.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, 
kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_20f7ea0aabd069362ba4bbd66623cea5b6e1a6bd.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_20f7ea0aabd069362ba4bbd66623cea5b6e1a6bd.hip new file mode 100644 index 000000000000..e723a9c2bf46 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_20f7ea0aabd069362ba4bbd66623cea5b6e1a6bd.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if 
(defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_210ef512b7862837f54acbc3b21e135a192647a3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_210ef512b7862837f54acbc3b21e135a192647a3.hip new file mode 100644 index 000000000000..9fcad67cae66 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_210ef512b7862837f54acbc3b21e135a192647a3.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, 
ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2122c973581930ab7a4ebc90b3bf1cdaa229a87f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2122c973581930ab7a4ebc90b3bf1cdaa229a87f.hip new file mode 100644 index 000000000000..dab9ce44e653 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2122c973581930ab7a4ebc90b3bf1cdaa229a87f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t 
kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_21411df58165946bf02942b597d94de7dd856987.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_21411df58165946bf02942b597d94de7dd856987.hip new file mode 100644 index 000000000000..4c82e0eed5b2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_21411df58165946bf02942b597d94de7dd856987.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + 
false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_216806a4598c885e517e664fc8280c59ec3cbf11.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_216806a4598c885e517e664fc8280c59ec3cbf11.hip new file mode 100644 index 000000000000..9370d29e8052 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_216806a4598c885e517e664fc8280c59ec3cbf11.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2173b7c710d418f44dc2b41bec5905024334eae5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2173b7c710d418f44dc2b41bec5905024334eae5.hip new file mode 100644 index 000000000000..80f87ac9f3a9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2173b7c710d418f44dc2b41bec5905024334eae5.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2177d95cdf45f6fec95d1812f2ef183a75259e38.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2177d95cdf45f6fec95d1812f2ef183a75259e38.hip new file mode 100644 index 000000000000..720880cc4ab9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2177d95cdf45f6fec95d1812f2ef183a75259e38.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_21828c7d3f5574690f12f841c27f025206e6165b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_21828c7d3f5574690f12f841c27f025206e6165b.hip new file mode 100644 index 000000000000..af05c411c24f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_21828c7d3f5574690f12f841c27f025206e6165b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2184fba2eec5899bb40d49d4508196e6be1ec1b1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2184fba2eec5899bb40d49d4508196e6be1ec1b1.hip new file mode 100644 index 000000000000..3b30ec85c774 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2184fba2eec5899bb40d49d4508196e6be1ec1b1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_21e235e31d6955393ac8e825bd69ead70687b7c8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_21e235e31d6955393ac8e825bd69ead70687b7c8.hip new file mode 100644 index 000000000000..7e33230130a0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_21e235e31d6955393ac8e825bd69ead70687b7c8.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_21f860d42fdc2cc6bd743d53ba546e332c22fedf.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_21f860d42fdc2cc6bd743d53ba546e332c22fedf.hip new file mode 100644 index 000000000000..a3b06ae0a85f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_21f860d42fdc2cc6bd743d53ba546e332c22fedf.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_22105635385fbfb5d2f330df83ba6747bcb27f6d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_22105635385fbfb5d2f330df83ba6747bcb27f6d.hip new file mode 100644 index 000000000000..a859ad0a5f04 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_22105635385fbfb5d2f330df83ba6747bcb27f6d.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 32, + true, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<32, + ck_tile::fp16_t, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_224f9af5e5ca519b21b71a54acb49f50b4999c47.hip 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_224f9af5e5ca519b21b71a54acb49f50b4999c47.hip new file mode 100644 index 000000000000..2fadab0c4c76 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_224f9af5e5ca519b21b71a54acb49f50b4999c47.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_22511de2592b6e350737e44865e1fed6496e3f32.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_22511de2592b6e350737e44865e1fed6496e3f32.hip new file mode 100644 index 000000000000..d12086cc2a2f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_22511de2592b6e350737e44865e1fed6496e3f32.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 64, + 256, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<256, + ck_tile::bf16_t, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_22632f996eb63fbe4bc5748c5897b775087446a0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_22632f996eb63fbe4bc5748c5897b775087446a0.hip new file mode 100644 index 000000000000..1412dad1cb3c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_22632f996eb63fbe4bc5748c5897b775087446a0.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_226662cf1c9900a4334d2cadcc5f5ac3ad355f05.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_226662cf1c9900a4334d2cadcc5f5ac3ad355f05.hip new file mode 100644 index 000000000000..f66e9193a3ae --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_226662cf1c9900a4334d2cadcc5f5ac3ad355f05.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2273457ac3be01cc1595a015a5f598f8290c77e4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2273457ac3be01cc1595a015a5f598f8290c77e4.hip new file mode 100644 index 000000000000..0cff307021e6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2273457ac3be01cc1595a015a5f598f8290c77e4.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, 
kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_22a07ecf1a59f72ec6bef3e970d7f33cf54c5f44.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_22a07ecf1a59f72ec6bef3e970d7f33cf54c5f44.hip new file mode 100644 index 000000000000..eb66c2a58e65 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_22a07ecf1a59f72ec6bef3e970d7f33cf54c5f44.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if 
(defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_22c142d869ef940ca876c93033ad53b576ed34f2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_22c142d869ef940ca876c93033ad53b576ed34f2.hip new file mode 100644 index 000000000000..14c3b0eb0166 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_22c142d869ef940ca876c93033ad53b576ed34f2.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using 
fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_23047ea90076e3b0a3eb0586d49b9ee74ca6d279.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_23047ea90076e3b0a3eb0586d49b9ee74ca6d279.hip new file mode 100644 index 000000000000..be2f64ccab68 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_23047ea90076e3b0a3eb0586d49b9ee74ca6d279.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 64, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<64, + ck_tile::bf16_t, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_230861e81e5acc523fa680534eed757b7b4a4e1d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_230861e81e5acc523fa680534eed757b7b4a4e1d.hip new file mode 100644 index 000000000000..d1c5a3fa9246 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_230861e81e5acc523fa680534eed757b7b4a4e1d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_232f61bf31dbb5de5d7039d5ff2338068a759b68.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_232f61bf31dbb5de5d7039d5ff2338068a759b68.hip new file mode 100644 index 000000000000..d26a92e4ecaa --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_232f61bf31dbb5de5d7039d5ff2338068a759b68.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_233132e712eba8972ba444c604f89e01c5b84cc0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_233132e712eba8972ba444c604f89e01c5b84cc0.hip new file mode 100644 index 000000000000..6d900a17bfa8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_233132e712eba8972ba444c604f89e01c5b84cc0.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 32, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<32, + ck_tile::bf16_t, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_235bf652702c2976551778b9159e09188575c63c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_235bf652702c2976551778b9159e09188575c63c.hip new file mode 100644 index 000000000000..c42575bbb159 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_235bf652702c2976551778b9159e09188575c63c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_236b3eef02b904304348b9d35f715b639d63218f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_236b3eef02b904304348b9d35f715b639d63218f.hip new file mode 100644 index 000000000000..e78ae2da9431 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_236b3eef02b904304348b9d35f715b639d63218f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_238e4c1ca112afec494fbe47a85b553302c43395.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_238e4c1ca112afec494fbe47a85b553302c43395.hip new file mode 100644 index 000000000000..a55761ed0585 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_238e4c1ca112afec494fbe47a85b553302c43395.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_23914c00690ac5c4f89cdbbaf00732ba66c5c0ef.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_23914c00690ac5c4f89cdbbaf00732ba66c5c0ef.hip new file mode 100644 index 000000000000..f86fd21a3de6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_23914c00690ac5c4f89cdbbaf00732ba66c5c0ef.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_23c9b46da8774462de8c24e14b12df3ed596eb57.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_23c9b46da8774462de8c24e14b12df3ed596eb57.hip new file mode 100644 index 000000000000..684f699a20a5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_23c9b46da8774462de8c24e14b12df3ed596eb57.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 128, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_242013527a0266ad479715ee3e6ae01c45de29d0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_242013527a0266ad479715ee3e6ae01c45de29d0.hip new file mode 100644 index 000000000000..ec3da5375c0f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_242013527a0266ad479715ee3e6ae01c45de29d0.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_24410fd9a4150c33186a2a365d06d8f6ea621c20.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_24410fd9a4150c33186a2a365d06d8f6ea621c20.hip new file mode 100644 index 000000000000..e57b6b2464d5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_24410fd9a4150c33186a2a365d06d8f6ea621c20.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_245d90000b55ab8b6055b1934880fc6c4870b34b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_245d90000b55ab8b6055b1934880fc6c4870b34b.hip new file mode 100644 index 000000000000..2d7d691e854f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_245d90000b55ab8b6055b1934880fc6c4870b34b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_24643917fc970c043d1c80d8d4b17ec92deeb8a1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_24643917fc970c043d1c80d8d4b17ec92deeb8a1.hip new file mode 100644 index 000000000000..9dcaec3e6ec3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_24643917fc970c043d1c80d8d4b17ec92deeb8a1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_249668a3212cd00edaae871758be30a5a1fea589.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_249668a3212cd00edaae871758be30a5a1fea589.hip new file mode 100644 index 000000000000..c8befdad639e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_249668a3212cd00edaae871758be30a5a1fea589.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = 
ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_249e6b93baae25dff97a0bc9145a8d328ed3f317.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_249e6b93baae25dff97a0bc9145a8d328ed3f317.hip new file mode 100644 index 000000000000..f385d543ec45 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_249e6b93baae25dff97a0bc9145a8d328ed3f317.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
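Each of these generated units follows the same pattern: it pins down one backward dq/dk/dv configuration (tile shape, dtype, mask, dropout, bias handling) and provides explicit specializations of fmha_bwd_dq_dk_dv_, fmha_bwd_dq_dk_dv_oneshot_, and fmha_bwd_dq_dk_dv_get_name_ keyed on the file-local dq_dk_dv_trait_0 alias; the launch is only compiled for gfx90a/gfx942 and returns 0.0 elsewhere. The sketch below is illustrative only and is not part of the generated sources: it shows how one such specialization could be called directly, assuming the same headers as the surrounding generated file, an already populated fmha_bwd_args named args, and a HIP stream named stream (both placeholders introduced for this example). It mirrors the stream_config construction from the stream id that the oneshot specializations above already use.

// Illustrative sketch, not generated code: invoking one specialization directly.
// `args` and `stream` are placeholders assumed to be prepared by the caller; in
// the real build the fmha_bwd dispatcher picks the matching trait at runtime.
float run_selected_bwd_kernel(const fmha_bwd_args& args, hipStream_t stream)
{
    // stream_config carries the HIP stream plus logging options; the generated
    // specializations read log_level_ and stream_id_ from it.
    ck_tile::stream_config cfg{stream};
    // The trait alias (one per generated file) selects the explicit
    // specialization; the float return is whatever ck_tile::launch_kernel
    // reports, or 0.0 on architectures other than gfx90a/gfx942.
    return fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(cfg, args);
}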
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2543da478310245e19e6c6a0d9ed7ad99540b3bc.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2543da478310245e19e6c6a0d9ed7ad99540b3bc.hip new file mode 100644 index 000000000000..171cdac4f03e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2543da478310245e19e6c6a0d9ed7ad99540b3bc.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_256ef175029a43e64164176d4eb212baf9d27bb9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_256ef175029a43e64164176d4eb212baf9d27bb9.hip new file mode 100644 index 000000000000..fbed56c47d96 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_256ef175029a43e64164176d4eb212baf9d27bb9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_258d747083272ea657604ac84867ecea17bd65da.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_258d747083272ea657604ac84867ecea17bd65da.hip new file mode 100644 index 000000000000..89bc7dd7e2ba --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_258d747083272ea657604ac84867ecea17bd65da.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_25938733446b6c0dcd159719f08d04a9aa467967.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_25938733446b6c0dcd159719f08d04a9aa467967.hip new file mode 100644 index 000000000000..a4af048e43f7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_25938733446b6c0dcd159719f08d04a9aa467967.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_25b3225da1e1842f83592971a1f62a0fe30aa9d3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_25b3225da1e1842f83592971a1f62a0fe30aa9d3.hip new file mode 100644 index 000000000000..525dc28431dd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_25b3225da1e1842f83592971a1f62a0fe30aa9d3.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); 
+#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2660282ad39ef034fecbdb74acedfb48620b7dfd.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2660282ad39ef034fecbdb74acedfb48620b7dfd.hip new file mode 100644 index 000000000000..f6d4f632d33a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2660282ad39ef034fecbdb74acedfb48620b7dfd.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_26835ba70606c769e56d19dbfe74061361aa855e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_26835ba70606c769e56d19dbfe74061361aa855e.hip new file mode 100644 index 000000000000..30ba2020a4eb --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_26835ba70606c769e56d19dbfe74061361aa855e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2695783ae8f0034692efd6563f789ef03fd0f4f3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2695783ae8f0034692efd6563f789ef03fd0f4f3.hip new file mode 100644 index 000000000000..af5313e1cab0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2695783ae8f0034692efd6563f789ef03fd0f4f3.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_26d77b228420a3ead919474ec9c6fb2800f86890.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_26d77b228420a3ead919474ec9c6fb2800f86890.hip new file mode 100644 index 000000000000..4544758fe65b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_26d77b228420a3ead919474ec9c6fb2800f86890.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS 
AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, 
grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_26ea90eb5a527434c1740933a1d2dd863eccf14c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_26ea90eb5a527434c1740933a1d2dd863eccf14c.hip new file mode 100644 index 000000000000..9bfdb6aef825 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_26ea90eb5a527434c1740933a1d2dd863eccf14c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_26f90358e522d7bb7c76c3a2c6010f0f38788bb6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_26f90358e522d7bb7c76c3a2c6010f0f38788bb6.hip new file mode 100644 index 000000000000..95c8658d8057 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_26f90358e522d7bb7c76c3a2c6010f0f38788bb6.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = 
k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2703018e71d57d3266fc35e2e18a78faa3dd52ce.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2703018e71d57d3266fc35e2e18a78faa3dd52ce.hip new file mode 100644 index 000000000000..a94494a3e4e6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2703018e71d57d3266fc35e2e18a78faa3dd52ce.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << 
std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_278639d44a4a8372a627a7c31e9527c8faa26f97.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_278639d44a4a8372a627a7c31e9527c8faa26f97.hip new file mode 100644 index 000000000000..40543edbadb1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_278639d44a4a8372a627a7c31e9527c8faa26f97.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 64, + true, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, false>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_27c2000d32c230a57a6712f27bc0fba02722f5fd.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_27c2000d32c230a57a6712f27bc0fba02722f5fd.hip new file mode 100644 index 000000000000..06f2ffea70c5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_27c2000d32c230a57a6712f27bc0fba02722f5fd.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_280bfced8745fbd9266207463fb41476dc23afff.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_280bfced8745fbd9266207463fb41476dc23afff.hip new file mode 100644 index 000000000000..2e857611381b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_280bfced8745fbd9266207463fb41476dc23afff.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; 
+ +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_281d897ad17d7f6db2741b396e6b85a9b8f35286.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_281d897ad17d7f6db2741b396e6b85a9b8f35286.hip new file mode 100644 index 000000000000..c4a5d12c6e6b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_281d897ad17d7f6db2741b396e6b85a9b8f35286.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_285e61dad8f63fb973cb2eb899c959e400622652.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_285e61dad8f63fb973cb2eb899c959e400622652.hip new file mode 100644 index 000000000000..38d2df11bb8c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_285e61dad8f63fb973cb2eb899c959e400622652.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_288458c5a0720ef152848713119ebce6d76db6d6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_288458c5a0720ef152848713119ebce6d76db6d6.hip new file mode 100644 index 000000000000..c0655b73a180 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_288458c5a0720ef152848713119ebce6d76db6d6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_289071756e7d0582eb61ce6483fa3c988d2e10b5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_289071756e7d0582eb61ce6483fa3c988d2e10b5.hip new file mode 100644 index 000000000000..7413b4f6e864 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_289071756e7d0582eb61ce6483fa3c988d2e10b5.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_28e4d2c757e4b8c366a2c320360e21ff0ef671a8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_28e4d2c757e4b8c366a2c320360e21ff0ef671a8.hip new file mode 100644 index 000000000000..baca96ca671e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_28e4d2c757e4b8c366a2c320360e21ff0ef671a8.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_28f1ef32c4384ec26f3dc5e3af6a74fc8cebae92.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_28f1ef32c4384ec26f3dc5e3af6a74fc8cebae92.hip new file mode 100644 index 000000000000..eae16006f2b9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_28f1ef32c4384ec26f3dc5e3af6a74fc8cebae92.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + 
s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_28f2e2b108a53308a0cb6c123c8d318cbc2eadb4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_28f2e2b108a53308a0cb6c123c8d318cbc2eadb4.hip new file mode 100644 index 000000000000..de4bee27c0d0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_28f2e2b108a53308a0cb6c123c8d318cbc2eadb4.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = 
fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_28f7634d29bef11fd466b452a46b0612f38c949b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_28f7634d29bef11fd466b452a46b0612f38c949b.hip new file mode 100644 index 000000000000..541d95b98042 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_28f7634d29bef11fd466b452a46b0612f38c949b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_290c484c2a366258941ee0051e139ea716a9de2f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_290c484c2a366258941ee0051e139ea716a9de2f.hip new file mode 100644 index 000000000000..6d221ce6c9e6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_290c484c2a366258941ee0051e139ea716a9de2f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_291a8bdf9d63b112e7fe5fa7e8835a6789cb8ecf.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_291a8bdf9d63b112e7fe5fa7e8835a6789cb8ecf.hip new file mode 100644 index 000000000000..90e416b7df93 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_291a8bdf9d63b112e7fe5fa7e8835a6789cb8ecf.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, 
kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_292454f2d82184ab0491ea0675750c6ec55d659c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_292454f2d82184ab0491ea0675750c6ec55d659c.hip new file mode 100644 index 000000000000..82e69e6b972b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_292454f2d82184ab0491ea0675750c6ec55d659c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_292b4f995d622826af5d1f2bffa7ba68467c841a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_292b4f995d622826af5d1f2bffa7ba68467c841a.hip new file mode 100644 index 000000000000..644c096be643 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_292b4f995d622826af5d1f2bffa7ba68467c841a.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 64, + true, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, false>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_295a523f815eb822d66162d4feb75fe0bc50b648.hip 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_295a523f815eb822d66162d4feb75fe0bc50b648.hip new file mode 100644 index 000000000000..f1e151bb0c27 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_295a523f815eb822d66162d4feb75fe0bc50b648.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 64, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_296c5836ba118969c4ba89ed62a98dffe3105738.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_296c5836ba118969c4ba89ed62a98dffe3105738.hip new file mode 100644 index 000000000000..6253aaf7b8dd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_296c5836ba118969c4ba89ed62a98dffe3105738.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 64, + 256, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<256, + ck_tile::bf16_t, + false, + true, + false, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2995d39cd62f20622a31f11a292ed175abb5fdf9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2995d39cd62f20622a31f11a292ed175abb5fdf9.hip new file mode 100644 index 000000000000..ea580c781dd7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2995d39cd62f20622a31f11a292ed175abb5fdf9.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 64, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<64, + ck_tile::fp16_t, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_29bffc159b0bb826ba489ae763dae141bfe8e802.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_29bffc159b0bb826ba489ae763dae141bfe8e802.hip new file mode 100644 index 000000000000..7c2551e94868 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_29bffc159b0bb826ba489ae763dae141bfe8e802.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_29c9e5384809b21f39e78bb2e43af345a9a21d19.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_29c9e5384809b21f39e78bb2e43af345a9a21d19.hip new file mode 100644 index 000000000000..f138f8cfe799 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_29c9e5384809b21f39e78bb2e43af345a9a21d19.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_29fe68ba10b3480dddc9866c51ca8b5efe962cc3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_29fe68ba10b3480dddc9866c51ca8b5efe962cc3.hip new file mode 100644 index 000000000000..f80cf2616331 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_29fe68ba10b3480dddc9866c51ca8b5efe962cc3.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, 
kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2a3a980a26682d879c3a3425f3ba5be3f5761adf.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2a3a980a26682d879c3a3425f3ba5be3f5761adf.hip new file mode 100644 index 000000000000..debe3536bf67 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2a3a980a26682d879c3a3425f3ba5be3f5761adf.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 128, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128, + ck_tile::bf16_t, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2a45129fc4995abcb8f880692f11c6186fc01641.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2a45129fc4995abcb8f880692f11c6186fc01641.hip new file mode 100644 index 000000000000..7ae88583ad53 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2a45129fc4995abcb8f880692f11c6186fc01641.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2a833fc01e88bd8e256ef64ae8251dd0ed10720b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2a833fc01e88bd8e256ef64ae8251dd0ed10720b.hip new file mode 100644 index 000000000000..52c67012d4f7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2a833fc01e88bd8e256ef64ae8251dd0ed10720b.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2a97c457144cb63a9c6c3d6be613b47bd0df9928.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2a97c457144cb63a9c6c3d6be613b47bd0df9928.hip new file mode 100644 index 000000000000..0ae1cf1c5cfd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2a97c457144cb63a9c6c3d6be613b47bd0df9928.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = 
fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2ad492377add5c8f6d0d2dbf9ee9e4338bbd9f1f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2ad492377add5c8f6d0d2dbf9ee9e4338bbd9f1f.hip new file mode 100644 index 000000000000..1ef34b4126bd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2ad492377add5c8f6d0d2dbf9ee9e4338bbd9f1f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if 
(defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2ae344010d49f7f9a6caab2cb84be7f87d2d96bf.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2ae344010d49f7f9a6caab2cb84be7f87d2d96bf.hip new file mode 100644 index 000000000000..62ac218caad3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2ae344010d49f7f9a6caab2cb84be7f87d2d96bf.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using 
fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2af6c5be53732eb1939a2f93232af7dc011dec1a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2af6c5be53732eb1939a2f93232af7dc011dec1a.hip new file mode 100644 index 000000000000..5417d1145fdf --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2af6c5be53732eb1939a2f93232af7dc011dec1a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b0bcb241e5a1be1d35366461408d06e095a26ef.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b0bcb241e5a1be1d35366461408d06e095a26ef.hip new file mode 100644 index 000000000000..9537b60aadd7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b0bcb241e5a1be1d35366461408d06e095a26ef.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b3326e055da32cc979892a2fbd0f7b003cb9f98.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b3326e055da32cc979892a2fbd0f7b003cb9f98.hip new file mode 100644 index 000000000000..991da6f6a43b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b3326e055da32cc979892a2fbd0f7b003cb9f98.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 64, + 256, + true, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<256, + ck_tile::fp16_t, + true, + true, + false, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b3af90387f1d227119c5dcd4b71362940bbce52.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b3af90387f1d227119c5dcd4b71362940bbce52.hip new file mode 100644 index 000000000000..32988018b506 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b3af90387f1d227119c5dcd4b71362940bbce52.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b4050988e5790a28dbe10b4c20e14f10f6cf85c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b4050988e5790a28dbe10b4c20e14f10f6cf85c.hip new file mode 100644 index 000000000000..ad3e20331fd2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b4050988e5790a28dbe10b4c20e14f10f6cf85c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b49a9b0801a06dd89c7f7182d7590b515df1592.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b49a9b0801a06dd89c7f7182d7590b515df1592.hip new file mode 100644 index 000000000000..244458feef1c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b49a9b0801a06dd89c7f7182d7590b515df1592.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); 
+#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b50073f6dfeb7ea77d5dce288a1d2f08f8f6362.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b50073f6dfeb7ea77d5dce288a1d2f08f8f6362.hip new file mode 100644 index 000000000000..8c4e06043bf1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b50073f6dfeb7ea77d5dce288a1d2f08f8f6362.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b5317b6cde327a842170ebff20c2b03d81379ff.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b5317b6cde327a842170ebff20c2b03d81379ff.hip new file mode 100644 index 000000000000..36cc57f19a7c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b5317b6cde327a842170ebff20c2b03d81379ff.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b8169ce4b4b9a17ac96fbb232e6a93f22071ab4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b8169ce4b4b9a17ac96fbb232e6a93f22071ab4.hip new file mode 100644 index 000000000000..a95c1e56edac --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b8169ce4b4b9a17ac96fbb232e6a93f22071ab4.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b823c3b99e7c8d1cdc39a5dbc7365a383bf9ccb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b823c3b99e7c8d1cdc39a5dbc7365a383bf9ccb.hip new file mode 100644 index 000000000000..2d54c326f1a4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2b823c3b99e7c8d1cdc39a5dbc7365a383bf9ccb.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2ba934408c75da5479cc41f96b98ea7d333635ea.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2ba934408c75da5479cc41f96b98ea7d333635ea.hip new file mode 100644 index 000000000000..e2166d14ee91 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2ba934408c75da5479cc41f96b98ea7d333635ea.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, 
grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2bb6da1095bd8669c0e48b5cd808cf0dcefa2674.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2bb6da1095bd8669c0e48b5cd808cf0dcefa2674.hip new file mode 100644 index 000000000000..4f11e3066c6c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2bb6da1095bd8669c0e48b5cd808cf0dcefa2674.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 32, + true, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, false>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2c0bda0feaade2b554d648d72f219ac9c389bf09.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2c0bda0feaade2b554d648d72f219ac9c389bf09.hip new file mode 100644 index 000000000000..66f2134b17be --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2c0bda0feaade2b554d648d72f219ac9c389bf09.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2c2e75e6f659a500dd3cf2cfd65118f111342119.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2c2e75e6f659a500dd3cf2cfd65118f111342119.hip new file mode 100644 index 000000000000..81ecee533c2d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2c2e75e6f659a500dd3cf2cfd65118f111342119.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2c77bd7e89ed832cc31b2995566a49bec6e4cb52.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2c77bd7e89ed832cc31b2995566a49bec6e4cb52.hip new file mode 100644 index 000000000000..1465eede4a1a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2c77bd7e89ed832cc31b2995566a49bec6e4cb52.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2c7aede7762a524a7a424cc4dc46e43fdedf73a2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2c7aede7762a524a7a424cc4dc46e43fdedf73a2.hip new file mode 100644 index 000000000000..7c747ffa8525 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2c7aede7762a524a7a424cc4dc46e43fdedf73a2.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2c808da5c2514806c2953bb77d5692e5d7c97aa3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2c808da5c2514806c2953bb77d5692e5d7c97aa3.hip new file mode 100644 index 000000000000..7bfcf3546a23 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2c808da5c2514806c2953bb77d5692e5d7c97aa3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2c82e3c4e445e1e02f14435e4ca01a90850139a4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2c82e3c4e445e1e02f14435e4ca01a90850139a4.hip new file mode 100644 index 000000000000..7041dde51e5f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2c82e3c4e445e1e02f14435e4ca01a90850139a4.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2c9756060ac0e73dbcfc58a9222a78f0283cd029.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2c9756060ac0e73dbcfc58a9222a78f0283cd029.hip new file mode 100644 index 000000000000..653c7f1a9477 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2c9756060ac0e73dbcfc58a9222a78f0283cd029.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2caba3ab83239e474412fcf89fe0fbef97e51bf1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2caba3ab83239e474412fcf89fe0fbef97e51bf1.hip new file mode 100644 index 000000000000..997e311f175f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2caba3ab83239e474412fcf89fe0fbef97e51bf1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2cf351fc2c2da4a8e1760a3affc9a5947c6b3bda.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2cf351fc2c2da4a8e1760a3affc9a5947c6b3bda.hip new file mode 100644 index 000000000000..f76fafe72e78 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2cf351fc2c2da4a8e1760a3affc9a5947c6b3bda.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2d06f77a4054ca615d96636c0e2eba2a89850142.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2d06f77a4054ca615d96636c0e2eba2a89850142.hip new file mode 100644 index 000000000000..8ac1c3f37dcc --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2d06f77a4054ca615d96636c0e2eba2a89850142.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 128, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2d1f2d1e57095f756ddd11e8e9d4f6f253e3ffa3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2d1f2d1e57095f756ddd11e8e9d4f6f253e3ffa3.hip new file mode 100644 index 000000000000..d0f866850352 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2d1f2d1e57095f756ddd11e8e9d4f6f253e3ffa3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2d23a26e0a59a8323dd97632e610d24624143fbe.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2d23a26e0a59a8323dd97632e610d24624143fbe.hip new file mode 100644 index 000000000000..da3fbc3f5842 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2d23a26e0a59a8323dd97632e610d24624143fbe.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2d43460c011b8d5e01ea98c9b8ddce962de59a96.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2d43460c011b8d5e01ea98c9b8ddce962de59a96.hip new file mode 100644 index 000000000000..f4ce61655c8d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2d43460c011b8d5e01ea98c9b8ddce962de59a96.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 128, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128, + ck_tile::bf16_t, + false, + true, + false, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2d446754d7000673779d15d3e73039fd3c10a720.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2d446754d7000673779d15d3e73039fd3c10a720.hip new file mode 100644 index 000000000000..90d14935193c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2d446754d7000673779d15d3e73039fd3c10a720.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2d7b637e0313cb423b22cd8844cc2997b3ff73e4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2d7b637e0313cb423b22cd8844cc2997b3ff73e4.hip new file mode 100644 index 000000000000..7e4018d6cadb --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2d7b637e0313cb423b22cd8844cc2997b3ff73e4.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2d9a04b7f41dd6f0db017157a44790f35c626e2d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2d9a04b7f41dd6f0db017157a44790f35c626e2d.hip new file mode 100644 index 000000000000..1980cd16e4f9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2d9a04b7f41dd6f0db017157a44790f35c626e2d.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2d9c659ba43bb907fd4e3e36a50958288bafd1a3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2d9c659ba43bb907fd4e3e36a50958288bafd1a3.hip new file mode 100644 index 000000000000..dc9e3df1cb40 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2d9c659ba43bb907fd4e3e36a50958288bafd1a3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2da2b905c4ce32234c2af62328adae6b1f9217a8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2da2b905c4ce32234c2af62328adae6b1f9217a8.hip new file mode 100644 index 000000000000..ac60af950a37 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2da2b905c4ce32234c2af62328adae6b1f9217a8.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2db33b5442d2e0948762b1f2147a321a9d6907be.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2db33b5442d2e0948762b1f2147a321a9d6907be.hip new file mode 100644 index 000000000000..4521da9fba51 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2db33b5442d2e0948762b1f2147a321a9d6907be.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = 
ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2dfac5a83def98340c8786d55a30a98ad68b9eed.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2dfac5a83def98340c8786d55a30a98ad68b9eed.hip new file mode 100644 index 000000000000..a3448cf7d456 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2dfac5a83def98340c8786d55a30a98ad68b9eed.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2e30f50071113dc4ab59468d568ac9deb06b0342.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2e30f50071113dc4ab59468d568ac9deb06b0342.hip new file mode 100644 index 000000000000..ffabb7038de9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2e30f50071113dc4ab59468d568ac9deb06b0342.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2e43e401abbfb1b6737e4dc822f68421abbc648a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2e43e401abbfb1b6737e4dc822f68421abbc648a.hip new file mode 100644 index 000000000000..8b920baad7d2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2e43e401abbfb1b6737e4dc822f68421abbc648a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; 
+ +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2e8b4260626beeac76c26dbcee3cba1457b30e99.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2e8b4260626beeac76c26dbcee3cba1457b30e99.hip new file mode 100644 index 000000000000..989fcedea2a2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2e8b4260626beeac76c26dbcee3cba1457b30e99.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2ea394a09c8691a534ad2219bedf73724b6dd5ce.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2ea394a09c8691a534ad2219bedf73724b6dd5ce.hip new file mode 100644 index 000000000000..1a3a6e656bc7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2ea394a09c8691a534ad2219bedf73724b6dd5ce.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2eba937ff6d0302ab013db7349d4feb914107f1f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2eba937ff6d0302ab013db7349d4feb914107f1f.hip new file mode 100644 index 000000000000..902e50fffcb9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2eba937ff6d0302ab013db7349d4feb914107f1f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; 
+ +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2f0247e301a7b076b6ec8a778c3b47e330638963.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2f0247e301a7b076b6ec8a778c3b47e330638963.hip new file mode 100644 index 000000000000..eaa36741dfca --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2f0247e301a7b076b6ec8a778c3b47e330638963.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2f32f2d658f1f69840fbad511ce8a3851c859d52.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2f32f2d658f1f69840fbad511ce8a3851c859d52.hip new file mode 100644 index 000000000000..9b8ea0448449 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2f32f2d658f1f69840fbad511ce8a3851c859d52.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2f55a23a0f24ff7062a4c286944f25d2db3e20a4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2f55a23a0f24ff7062a4c286944f25d2db3e20a4.hip new file mode 100644 index 000000000000..b5a1f73e4066 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_2f55a23a0f24ff7062a4c286944f25d2db3e20a4.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_30024440e780fdf9ec94deccc85216d8bbb5788a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_30024440e780fdf9ec94deccc85216d8bbb5788a.hip new file mode 100644 index 000000000000..4d706362bb13 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_30024440e780fdf9ec94deccc85216d8bbb5788a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_303b7b04496e4db7c1ba2436485dc7c8a4c88448.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_303b7b04496e4db7c1ba2436485dc7c8a4c88448.hip new file mode 100644 index 000000000000..4e73287a50fd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_303b7b04496e4db7c1ba2436485dc7c8a4c88448.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = 
ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3076a6de0e2612279e0ed64612f7393856bcc9ac.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3076a6de0e2612279e0ed64612f7393856bcc9ac.hip new file mode 100644 index 000000000000..dc18022abcf0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3076a6de0e2612279e0ed64612f7393856bcc9ac.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_30c8e4d5c761fda50e010da779e8e4730051d403.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_30c8e4d5c761fda50e010da779e8e4730051d403.hip new file mode 100644 index 000000000000..accc3bf3513f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_30c8e4d5c761fda50e010da779e8e4730051d403.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_30f0200092b0e18d57a9f5e512d565f1c0229436.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_30f0200092b0e18d57a9f5e512d565f1c0229436.hip new file mode 100644 index 000000000000..41dad6deb6bd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_30f0200092b0e18d57a9f5e512d565f1c0229436.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3108502fd29d3a24b32177bcea968121ee809115.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3108502fd29d3a24b32177bcea968121ee809115.hip new file mode 100644 index 000000000000..19a0b7459767 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3108502fd29d3a24b32177bcea968121ee809115.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3110540b50e95e99a5cccebe47d9d3a83093c2fb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3110540b50e95e99a5cccebe47d9d3a83093c2fb.hip new file mode 100644 index 000000000000..3a7c3a96d146 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3110540b50e95e99a5cccebe47d9d3a83093c2fb.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_311104394c8bef8d4ecff35c1409221e723a5a8a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_311104394c8bef8d4ecff35c1409221e723a5a8a.hip new file mode 100644 index 000000000000..595552c2eb4a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_311104394c8bef8d4ecff35c1409221e723a5a8a.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_311731442b756308c0a869f21b7b8b103aa613e8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_311731442b756308c0a869f21b7b8b103aa613e8.hip new file mode 100644 index 000000000000..7d8501d7ac30 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_311731442b756308c0a869f21b7b8b103aa613e8.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_31222e158484773d2257f4a31e3dfbdb68336a8e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_31222e158484773d2257f4a31e3dfbdb68336a8e.hip new file mode 100644 index 000000000000..1e69eeab2521 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_31222e158484773d2257f4a31e3dfbdb68336a8e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3163272d25bc2db2ffaa1fea87648b45ee68d408.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3163272d25bc2db2ffaa1fea87648b45ee68d408.hip new file mode 100644 index 000000000000..3a4bf5f9985e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3163272d25bc2db2ffaa1fea87648b45ee68d408.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_319df310195191895005b30151da8c1afab6c82f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_319df310195191895005b30151da8c1afab6c82f.hip new file mode 100644 index 000000000000..f1e5c3091e19 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_319df310195191895005b30151da8c1afab6c82f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_31a968898f0bc6366313e41eddb5e3a3ed12dc98.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_31a968898f0bc6366313e41eddb5e3a3ed12dc98.hip new file mode 100644 index 000000000000..3dc985e15d1c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_31a968898f0bc6366313e41eddb5e3a3ed12dc98.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_31b807c48c472e9b1311a6037cd98e21d6706889.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_31b807c48c472e9b1311a6037cd98e21d6706889.hip new file mode 100644 index 000000000000..5656854a8dce --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_31b807c48c472e9b1311a6037cd98e21d6706889.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_31c3760f5978baf9780ce4587ae4c768af0e49d1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_31c3760f5978baf9780ce4587ae4c768af0e49d1.hip new file mode 100644 index 000000000000..b52b67fa2fdf --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_31c3760f5978baf9780ce4587ae4c768af0e49d1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) 
+ return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_31c4b866692ba5c3d115482bef4790733863c1fc.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_31c4b866692ba5c3d115482bef4790733863c1fc.hip new file mode 100644 index 000000000000..cc6c0f0fbad3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_31c4b866692ba5c3d115482bef4790733863c1fc.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3206cc121ce8955ed59ea3b12b858ee2e0cf82f8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3206cc121ce8955ed59ea3b12b858ee2e0cf82f8.hip new file mode 100644 index 000000000000..32c52d79a318 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3206cc121ce8955ed59ea3b12b858ee2e0cf82f8.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_320a6196b662a1d3dc7441a9536d825dc356b95d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_320a6196b662a1d3dc7441a9536d825dc356b95d.hip new file mode 100644 index 000000000000..c8f8761d6e21 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_320a6196b662a1d3dc7441a9536d825dc356b95d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_321500dd4c41e4d68834814a48a639f5ca36a2fb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_321500dd4c41e4d68834814a48a639f5ca36a2fb.hip new file mode 100644 index 000000000000..2e522d847bf1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_321500dd4c41e4d68834814a48a639f5ca36a2fb.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_322a86568f89a5a5a165cfffbae9ca6949f2477e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_322a86568f89a5a5a165cfffbae9ca6949f2477e.hip new file mode 100644 index 000000000000..878cbe7e677c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_322a86568f89a5a5a165cfffbae9ca6949f2477e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_32438250078ba2a47345ec4955dafb4e4de78a25.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_32438250078ba2a47345ec4955dafb4e4de78a25.hip new file mode 100644 index 000000000000..92b96688e2ba --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_32438250078ba2a47345ec4955dafb4e4de78a25.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_32527660fa7aeb9a951a9f2fc3c53989bd141c48.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_32527660fa7aeb9a951a9f2fc3c53989bd141c48.hip new file mode 100644 index 000000000000..7d3ca38f3521 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_32527660fa7aeb9a951a9f2fc3c53989bd141c48.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); 
+#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_325fbcb9e503e68fafea08abf86a4951f440850f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_325fbcb9e503e68fafea08abf86a4951f440850f.hip new file mode 100644 index 000000000000..8815c227c29d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_325fbcb9e503e68fafea08abf86a4951f440850f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_32652a27e8605cef59c8341813b68e7513be23c5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_32652a27e8605cef59c8341813b68e7513be23c5.hip new file mode 100644 index 000000000000..2e1046225d2c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_32652a27e8605cef59c8341813b68e7513be23c5.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, 
ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_327e27892bc57f3dec0da24f94f2a483d6c9321b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_327e27892bc57f3dec0da24f94f2a483d6c9321b.hip new file mode 100644 index 000000000000..54bbc59eded6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_327e27892bc57f3dec0da24f94f2a483d6c9321b.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 64, + 256, + true, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<256, + ck_tile::bf16_t, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_328a311bafd1c153525393b252e4170f8aafb370.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_328a311bafd1c153525393b252e4170f8aafb370.hip new file mode 100644 index 000000000000..140bea9b3005 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_328a311bafd1c153525393b252e4170f8aafb370.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_33099fcfc218ffdf69edb4f2f0e46121bea9fafc.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_33099fcfc218ffdf69edb4f2f0e46121bea9fafc.hip new file mode 100644 index 000000000000..73ca1ca886d7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_33099fcfc218ffdf69edb4f2f0e46121bea9fafc.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_33746071156e9ad46f403a539dc237e0a44122a7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_33746071156e9ad46f403a539dc237e0a44122a7.hip new file mode 100644 index 000000000000..cf909ff4b49e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_33746071156e9ad46f403a539dc237e0a44122a7.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 256, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_33e7c1e5f41a451c7baff54f7238b220f1bdf8a1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_33e7c1e5f41a451c7baff54f7238b220f1bdf8a1.hip new file mode 
100644 index 000000000000..0c33783352c1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_33e7c1e5f41a451c7baff54f7238b220f1bdf8a1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = 
fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3400f0af03743dce328486f8fc805dd30bd6da31.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3400f0af03743dce328486f8fc805dd30bd6da31.hip new file mode 100644 index 000000000000..5bdb954c7464 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3400f0af03743dce328486f8fc805dd30bd6da31.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + 
false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3408103188e27b3bc55dce0c1716c0b4d32d6494.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3408103188e27b3bc55dce0c1716c0b4d32d6494.hip new file mode 100644 index 000000000000..b434aed4fcac --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3408103188e27b3bc55dce0c1716c0b4d32d6494.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
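+//
+// Illustrative note (editorial assumption, not emitted by generate.py): each of
+// these generated translation units pins exactly one backward-kernel
+// configuration (data type, block/warp tiles, pipeline, bias/dropout/padding
+// flags) and then provides explicit specializations of fmha_bwd_dq_dk_dv_,
+// fmha_bwd_dq_dk_dv_oneshot_, and fmha_bwd_dq_dk_dv_get_name_. Assuming the
+// dispatcher selects this unit through the trait alias defined below, a call
+// would look roughly like:
+//
+//   ck_tile::stream_config cfg{hip_stream};                        // hip_stream: assumed HIP stream handle
+//   float ms = fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(cfg, args);    // args: fmha_bwd_args filled by the caller
+//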
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_342d29c85070f488a14b1915f948e5fd69019c99.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_342d29c85070f488a14b1915f948e5fd69019c99.hip new file mode 100644 index 000000000000..d68768627db5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_342d29c85070f488a14b1915f948e5fd69019c99.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_344932e2655d7b32704be8de9a63bbd8c3369f02.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_344932e2655d7b32704be8de9a63bbd8c3369f02.hip new file mode 100644 index 000000000000..51beb87a2b3c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_344932e2655d7b32704be8de9a63bbd8c3369f02.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 64, + 256, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<256, + ck_tile::bf16_t, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_345a939a2491166dc520e9a2b9de7e43671e0c2b.hip 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_345a939a2491166dc520e9a2b9de7e43671e0c2b.hip new file mode 100644 index 000000000000..adf03ddef034 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_345a939a2491166dc520e9a2b9de7e43671e0c2b.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 32, + true, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, true>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_345ea796c8d97bfe3b7c9663bf15e2e5e7696235.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_345ea796c8d97bfe3b7c9663bf15e2e5e7696235.hip new file mode 100644 index 000000000000..7f7027e9a8e8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_345ea796c8d97bfe3b7c9663bf15e2e5e7696235.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
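+//
+// Illustrative note (editorial assumption, not emitted by generate.py): the
+// launcher bodies below are guarded by
+// `#if (defined(__gfx90a__) || defined(__gfx942__))`, so the kernel is only
+// launched when this translation unit is compiled for gfx90a/gfx942
+// (MI200/MI300-class) targets; for any other architecture the timed launcher
+// returns 0.0 and the one-shot launcher is a no-op.
+//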
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_34807a8e90bf1cd839f32fd718afa6469c35a4fa.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_34807a8e90bf1cd839f32fd718afa6469c35a4fa.hip new file mode 100644 index 000000000000..ddbca1e546cf --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_34807a8e90bf1cd839f32fd718afa6469c35a4fa.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_349241529745bf138552f49d9a93db418663ad65.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_349241529745bf138552f49d9a93db418663ad65.hip new file mode 100644 index 000000000000..58b13dfa55dd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_349241529745bf138552f49d9a93db418663ad65.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_34c2db98d8e2e690f499f41cfd5afb831b756f54.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_34c2db98d8e2e690f499f41cfd5afb831b756f54.hip new file mode 100644 index 000000000000..3e4446f52ffa --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_34c2db98d8e2e690f499f41cfd5afb831b756f54.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
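+//
+// Illustrative note (editorial assumption, not emitted by generate.py): unlike
+// the backward files above, this translation unit instantiates a single forward
+// kernel (ck_tile::FmhaFwdKernel over the QRKSVS_ASYNC pipeline) and specializes
+// only the timed launcher. Assuming dispatch goes through the trait alias
+// defined below, a call would look roughly like:
+//
+//   float ms = fmha_fwd_<trait_0>(ck_tile::stream_config{hip_stream}, fwd_args);  // fwd_args: fmha_fwd_args
+//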
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3511c54e6a6f9eec378d8b661121066536195d3a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3511c54e6a6f9eec378d8b661121066536195d3a.hip new file mode 100644 index 000000000000..83b2644097c2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3511c54e6a6f9eec378d8b661121066536195d3a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_351425a006aeeff4d69c8570cb6bf1e1427d2c21.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_351425a006aeeff4d69c8570cb6bf1e1427d2c21.hip new file mode 100644 index 000000000000..98574609d777 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_351425a006aeeff4d69c8570cb6bf1e1427d2c21.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_354121d3bad1d448bd413718fa096f54faa12e95.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_354121d3bad1d448bd413718fa096f54faa12e95.hip new file mode 100644 index 000000000000..5e1fe0ac82ca --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_354121d3bad1d448bd413718fa096f54faa12e95.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + 
return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_356f83cb96d0313abcdb24955edd4264df72aed7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_356f83cb96d0313abcdb24955edd4264df72aed7.hip new file mode 100644 index 000000000000..9f420d7fdfc7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_356f83cb96d0313abcdb24955edd4264df72aed7.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_357f7e626135cc9176a295f3d1f336a7c3852688.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_357f7e626135cc9176a295f3d1f336a7c3852688.hip new file mode 100644 index 000000000000..500ae2c9a2eb --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_357f7e626135cc9176a295f3d1f336a7c3852688.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_358399e756ed5026baf3ab78af17489dc07b9532.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_358399e756ed5026baf3ab78af17489dc07b9532.hip new file mode 100644 index 000000000000..fd0a60eb88e0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_358399e756ed5026baf3ab78af17489dc07b9532.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_358d28c958c0a831a615a4811d13279b18db09c4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_358d28c958c0a831a615a4811d13279b18db09c4.hip new file mode 100644 index 000000000000..ecdcdfe14955 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_358d28c958c0a831a615a4811d13279b18db09c4.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = 
ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3642b78913a853a62dbff8b99d9ae3fa458f461d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3642b78913a853a62dbff8b99d9ae3fa458f461d.hip new file mode 100644 index 000000000000..0a74851481f6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3642b78913a853a62dbff8b99d9ae3fa458f461d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_366662dccf2f650bcd8123c49006c759cd4c0ef6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_366662dccf2f650bcd8123c49006c759cd4c0ef6.hip new file mode 100644 index 000000000000..70e55fe4b940 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_366662dccf2f650bcd8123c49006c759cd4c0ef6.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_367e58867c46d96c9bbaa96eaaa9f93595c9e099.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_367e58867c46d96c9bbaa96eaaa9f93595c9e099.hip new file mode 100644 index 000000000000..9a504bf3b41a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_367e58867c46d96c9bbaa96eaaa9f93595c9e099.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_36a0a960541bd8a2dc6741579de685b7c0a5f6d7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_36a0a960541bd8a2dc6741579de685b7c0a5f6d7.hip new file mode 100644 index 000000000000..52bec5174646 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_36a0a960541bd8a2dc6741579de685b7c0a5f6d7.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_377b70f54cb2778b5ce3df936b477f775eea8b3c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_377b70f54cb2778b5ce3df936b477f775eea8b3c.hip new file mode 100644 index 000000000000..1aec4fa96c6a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_377b70f54cb2778b5ce3df936b477f775eea8b3c.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 32, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<32, + ck_tile::fp16_t, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_378759ae25465c32960487375828e23c5f1ac869.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_378759ae25465c32960487375828e23c5f1ac869.hip new file mode 100644 index 000000000000..7b20483d49cb --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_378759ae25465c32960487375828e23c5f1ac869.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_378bf438642e5d863e31145ada2a0688059aa5d9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_378bf438642e5d863e31145ada2a0688059aa5d9.hip new file mode 100644 index 000000000000..907a5f0d39d1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_378bf438642e5d863e31145ada2a0688059aa5d9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_37ad61bf8427a26775969f8a9166fd0bfb7446b4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_37ad61bf8427a26775969f8a9166fd0bfb7446b4.hip new file mode 100644 index 000000000000..44d320675fc9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_37ad61bf8427a26775969f8a9166fd0bfb7446b4.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; 
+ +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_37fe04467e87ec2110f60c7aea0cc9bf2ca07481.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_37fe04467e87ec2110f60c7aea0cc9bf2ca07481.hip new file mode 100644 index 000000000000..53602634b9cc --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_37fe04467e87ec2110f60c7aea0cc9bf2ca07481.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_38010c9bf7341588f071f889b7a0b4dcc4e7a14c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_38010c9bf7341588f071f889b7a0b4dcc4e7a14c.hip new file mode 100644 index 000000000000..ed72982e3e9c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_38010c9bf7341588f071f889b7a0b4dcc4e7a14c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_381b29d9888365bff0f109d897b508eebfd8a61f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_381b29d9888365bff0f109d897b508eebfd8a61f.hip new file mode 100644 index 000000000000..d4fb08303056 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_381b29d9888365bff0f109d897b508eebfd8a61f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3824e97d5ecba46e06d5ec1a9456c810d80227a3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3824e97d5ecba46e06d5ec1a9456c810d80227a3.hip new file mode 100644 index 000000000000..58548b3c949c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3824e97d5ecba46e06d5ec1a9456c810d80227a3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_38273a2f8e6bbb42ba0b0871b6c95abb34531f33.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_38273a2f8e6bbb42ba0b0871b6c95abb34531f33.hip new file mode 100644 index 000000000000..5836a4b4a2df --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_38273a2f8e6bbb42ba0b0871b6c95abb34531f33.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 128, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128, + ck_tile::bf16_t, + false, + false, + true, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_38a5ff72f22e0ad040a281e66b1aca0bf3a2aadb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_38a5ff72f22e0ad040a281e66b1aca0bf3a2aadb.hip new file mode 100644 index 000000000000..46e36eba6efb --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_38a5ff72f22e0ad040a281e66b1aca0bf3a2aadb.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_38abcbeaa4d33d3150f2b0238bb62ebbfe960980.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_38abcbeaa4d33d3150f2b0238bb62ebbfe960980.hip new file mode 100644 index 000000000000..1a297aa8e40b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_38abcbeaa4d33d3150f2b0238bb62ebbfe960980.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_38b94d76503e13c911781169fbc378517332c42e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_38b94d76503e13c911781169fbc378517332c42e.hip new file mode 100644 index 000000000000..b8e7a9eb57d1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_38b94d76503e13c911781169fbc378517332c42e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = 
+ ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_38bb367362fe2c4849ded728ec5dd00969ce188f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_38bb367362fe2c4849ded728ec5dd00969ce188f.hip new file mode 100644 index 000000000000..0955f50c4880 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_38bb367362fe2c4849ded728ec5dd00969ce188f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_38e12dad9e3bafe177ed3c27c833825813e18fc3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_38e12dad9e3bafe177ed3c27c833825813e18fc3.hip new file mode 100644 index 000000000000..951ef1eb5db4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_38e12dad9e3bafe177ed3c27c833825813e18fc3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_38f8a89468cf9c8606cf12a930db062a83cd0ea0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_38f8a89468cf9c8606cf12a930db062a83cd0ea0.hip new file mode 100644 index 000000000000..872a38cd720b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_38f8a89468cf9c8606cf12a930db062a83cd0ea0.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 32, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3937d9dfb68351de2942e32f35e2ca1ce71edfa8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3937d9dfb68351de2942e32f35e2ca1ce71edfa8.hip new file mode 
100644 index 000000000000..3ec1276181a1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3937d9dfb68351de2942e32f35e2ca1ce71edfa8.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto 
[kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_39422621a00ff79b2f5ec0dafb957c77693537b3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_39422621a00ff79b2f5ec0dafb957c77693537b3.hip new file mode 100644 index 000000000000..2a8d74f47e2b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_39422621a00ff79b2f5ec0dafb957c77693537b3.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3967a8807c9451b09227c0f685c18aafeb062fd2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3967a8807c9451b09227c0f685c18aafeb062fd2.hip new 
file mode 100644 index 000000000000..0ddbbcdaaa53 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3967a8807c9451b09227c0f685c18aafeb062fd2.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + 
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3992d5df4ba2e999caf6889a852db4e1ba078e65.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3992d5df4ba2e999caf6889a852db4e1ba078e65.hip new file mode 100644 index 000000000000..dbef7fd0994f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3992d5df4ba2e999caf6889a852db4e1ba078e65.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_39d3071347a0c98f3221104036f477aa13bffa4d.hip 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_39d3071347a0c98f3221104036f477aa13bffa4d.hip new file mode 100644 index 000000000000..457fc9f0fc9b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_39d3071347a0c98f3221104036f477aa13bffa4d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> 
+void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3a1dca5feb864e8981387c2d07e62acef1730aa8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3a1dca5feb864e8981387c2d07e62acef1730aa8.hip new file mode 100644 index 000000000000..d937f8f21b43 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3a1dca5feb864e8981387c2d07e62acef1730aa8.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + 
fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3a2280997eb6f1d091094fc54cecf42b7c9c3a2d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3a2280997eb6f1d091094fc54cecf42b7c9c3a2d.hip new file mode 100644 index 000000000000..ff32e3786f24 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3a2280997eb6f1d091094fc54cecf42b7c9c3a2d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3a2643099365d0903c799585f41dc1a525ac9f9e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3a2643099365d0903c799585f41dc1a525ac9f9e.hip new file mode 100644 index 000000000000..5805472bea7c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3a2643099365d0903c799585f41dc1a525ac9f9e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3a6b9566559ed2b1c85f2bea1c55e72c41dc47bd.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3a6b9566559ed2b1c85f2bea1c55e72c41dc47bd.hip new file mode 100644 index 000000000000..87d6878c3e75 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3a6b9566559ed2b1c85f2bea1c55e72c41dc47bd.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3af86f458fb4dfcceb7db3357fbae0dc15142a15.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3af86f458fb4dfcceb7db3357fbae0dc15142a15.hip new file mode 100644 index 000000000000..c8cead574041 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3af86f458fb4dfcceb7db3357fbae0dc15142a15.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3afbb5ac9048a962a60f48886728220ae6c2aeaf.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3afbb5ac9048a962a60f48886728220ae6c2aeaf.hip new file mode 100644 index 000000000000..8dd1c686396f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3afbb5ac9048a962a60f48886728220ae6c2aeaf.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 32, + true, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<32, + ck_tile::fp16_t, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3b26eafe76cca8e74e819220b6de1f4279d48e43.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3b26eafe76cca8e74e819220b6de1f4279d48e43.hip new file mode 100644 index 000000000000..5a321a93c5fd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3b26eafe76cca8e74e819220b6de1f4279d48e43.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3b4ecb47f9ebe8c2784976c3e9bbe4834b475cf1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3b4ecb47f9ebe8c2784976c3e9bbe4834b475cf1.hip new file mode 100644 index 000000000000..3016cd469f3a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3b4ecb47f9ebe8c2784976c3e9bbe4834b475cf1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
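
Each autogenerated instance in this directory follows the same dispatch pattern: a trait alias (dq_dk_dv_trait_0 / trait_0) pins down one tile-shape, dtype, mask, bias and dropout configuration, and the file provides the explicit template specialization of the launcher for exactly that trait, with the launch guarded by the gfx90a/gfx942 macros so builds for other targets compile the body to a stub returning 0.0. What follows is a minimal, self-contained sketch of that specialization-per-configuration idea; the names ExampleTraits and run_fmha_instance_ are placeholders for illustration only and are not part of the ck_tile API.

// Minimal illustration (not ck_tile code): one explicit specialization per
// generated configuration, guarded so unsupported targets become a no-op.
#include <iostream>

// Placeholder trait standing in for one generated configuration
// (head dim, dtype, mask, dropout flags, ...).
template <int HeadDim, bool kIsCausal>
struct ExampleTraits {};

// The primary template is only declared; each generated translation unit
// supplies exactly one specialization, so a missing configuration fails
// at link time instead of silently running the wrong kernel.
template <typename Traits>
float run_fmha_instance_(float elapsed_ms_seed);

// The "generated" instance for the <64, true> configuration.
template <>
float run_fmha_instance_<ExampleTraits<64, true>>(float elapsed_ms_seed)
{
#if (defined(__gfx90a__) || defined(__gfx942__))
    // A real instance would build kernel arguments and launch the HIP kernel
    // here, returning the elapsed time reported by the launch helper.
    return elapsed_ms_seed;
#else
    // Mirrors the generated files: other targets compile to a stub.
    return 0.0f;
#endif
}

int main()
{
    std::cout << run_fmha_instance_<ExampleTraits<64, true>>(1.0f) << '\n';
    return 0;
}

Emitting one translation unit per trait combination, as the generator does here, keeps per-file compile times bounded and lets the build include only the head-dim/dtype/pipeline combinations it needs; the sketch above only illustrates the dispatch shape, not the actual ck_tile launch interface.
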
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3b508b92f7e123b21658f6e17d624ffa87831fee.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3b508b92f7e123b21658f6e17d624ffa87831fee.hip new file mode 100644 index 000000000000..0bb1a69aa48f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3b508b92f7e123b21658f6e17d624ffa87831fee.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3b5b3c218e4a7b459e54080e24c5b730221eac02.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3b5b3c218e4a7b459e54080e24c5b730221eac02.hip new file mode 100644 index 000000000000..721ff0851132 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3b5b3c218e4a7b459e54080e24c5b730221eac02.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3bb129e6dee6848043dd0e8fa812ae80fec4d014.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3bb129e6dee6848043dd0e8fa812ae80fec4d014.hip new file mode 100644 index 000000000000..4cc230a14a7f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3bb129e6dee6848043dd0e8fa812ae80fec4d014.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3bb3b682eab96e4e173affad75b9d8e73f1dd690.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3bb3b682eab96e4e173affad75b9d8e73f1dd690.hip new file mode 100644 index 000000000000..ae316743ab68 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3bb3b682eab96e4e173affad75b9d8e73f1dd690.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3be7cea6df8e6dd56194e1172f28943667f1c4ef.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3be7cea6df8e6dd56194e1172f28943667f1c4ef.hip new file mode 100644 index 000000000000..f9175f9604cd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3be7cea6df8e6dd56194e1172f28943667f1c4ef.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 
0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3bed3aaf24c73073c604a3b23bb4b0358b8e3490.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3bed3aaf24c73073c604a3b23bb4b0358b8e3490.hip new file mode 100644 index 000000000000..128a085e485d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3bed3aaf24c73073c604a3b23bb4b0358b8e3490.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3c1454ffc1418dac641f63671e947d9f550b1f0c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3c1454ffc1418dac641f63671e947d9f550b1f0c.hip new file mode 100644 index 000000000000..5654f69141fe --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3c1454ffc1418dac641f63671e947d9f550b1f0c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( 
+ ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3c38bb80e9880335faaea81985ed5d0e713ecb08.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3c38bb80e9880335faaea81985ed5d0e713ecb08.hip new file mode 100644 index 000000000000..5798ef24307c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3c38bb80e9880335faaea81985ed5d0e713ecb08.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3c3b7e4b8c1efe59f79a15512716fce2282a79a7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3c3b7e4b8c1efe59f79a15512716fce2282a79a7.hip new file mode 100644 index 000000000000..fd12889cf474 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3c3b7e4b8c1efe59f79a15512716fce2282a79a7.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 128, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3c64c33870ebc329921cfa3867d58b1857421f65.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3c64c33870ebc329921cfa3867d58b1857421f65.hip new file mode 100644 index 000000000000..5115f25ea325 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3c64c33870ebc329921cfa3867d58b1857421f65.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + 
s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3cb0cee09d633b6f70febbba63a1e090522cfb4a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3cb0cee09d633b6f70febbba63a1e090522cfb4a.hip new file mode 100644 index 000000000000..dfeef4e22f26 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3cb0cee09d633b6f70febbba63a1e090522cfb4a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = 
fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3cce3baac1e3ca03af0c3f4ee4d0158ad1031e9f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3cce3baac1e3ca03af0c3f4ee4d0158ad1031e9f.hip new file mode 100644 index 000000000000..1121d7136311 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3cce3baac1e3ca03af0c3f4ee4d0158ad1031e9f.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3ccf0a9d5a5451da5dbf6075ccea45e4a140550a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3ccf0a9d5a5451da5dbf6075ccea45e4a140550a.hip new file mode 100644 index 000000000000..9e84ec8cd3a6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3ccf0a9d5a5451da5dbf6075ccea45e4a140550a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3cd7a9ca49c1149d46f6b05b0fefc41ecaeb6ea1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3cd7a9ca49c1149d46f6b05b0fefc41ecaeb6ea1.hip new file mode 100644 index 000000000000..68f833e3b58b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3cd7a9ca49c1149d46f6b05b0fefc41ecaeb6ea1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3cf45927b6d931e31e2209685d787efa28eed8ba.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3cf45927b6d931e31e2209685d787efa28eed8ba.hip new file mode 100644 index 000000000000..8d24261fb728 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3cf45927b6d931e31e2209685d787efa28eed8ba.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3d1cea88a2277b87d405025ba256272a1720f88d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3d1cea88a2277b87d405025ba256272a1720f88d.hip new file mode 100644 index 000000000000..86d662d13740 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3d1cea88a2277b87d405025ba256272a1720f88d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3d289100991d4c8c362f64c8f6c4ba395c2f3495.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3d289100991d4c8c362f64c8f6c4ba395c2f3495.hip new file mode 100644 index 000000000000..0e52b9cee61e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3d289100991d4c8c362f64c8f6c4ba395c2f3495.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3d3f3eb2f5eb1f3287879604892b1c230df85f1d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3d3f3eb2f5eb1f3287879604892b1c230df85f1d.hip new file mode 100644 index 000000000000..e480475f08e0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3d3f3eb2f5eb1f3287879604892b1c230df85f1d.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); 
+#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3d45624dc6e33c477c73a155500b015b6c010de8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3d45624dc6e33c477c73a155500b015b6c010de8.hip new file mode 100644 index 000000000000..2e713305b18c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3d45624dc6e33c477c73a155500b015b6c010de8.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 256, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3d55cb42b0096a8ae338ce100f86e378aa1a04c9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3d55cb42b0096a8ae338ce100f86e378aa1a04c9.hip new file mode 100644 index 000000000000..3d4188634ad2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3d55cb42b0096a8ae338ce100f86e378aa1a04c9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3da8c31f6d5bcaacfa4a21aed4d1d3caecb48922.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3da8c31f6d5bcaacfa4a21aed4d1d3caecb48922.hip new file mode 100644 index 000000000000..45354957b434 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3da8c31f6d5bcaacfa4a21aed4d1d3caecb48922.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3dba3cd44f78c950fe7ceaa5f0629dfc607b30f1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3dba3cd44f78c950fe7ceaa5f0629dfc607b30f1.hip new file mode 100644 index 000000000000..e98c243b41b6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3dba3cd44f78c950fe7ceaa5f0629dfc607b30f1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3dff884e176ec7cff86d17c6afe1ddaa4dd6007d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3dff884e176ec7cff86d17c6afe1ddaa4dd6007d.hip new file mode 100644 index 000000000000..2d38d9ce9e66 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3dff884e176ec7cff86d17c6afe1ddaa4dd6007d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3e143d88eaa0d9cfea856b2f3a57d1275a656627.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3e143d88eaa0d9cfea856b2f3a57d1275a656627.hip new file mode 100644 index 000000000000..e2bdb517eadc --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3e143d88eaa0d9cfea856b2f3a57d1275a656627.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3e2557f206fd81d82a3b9d59113105040beb891f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3e2557f206fd81d82a3b9d59113105040beb891f.hip new file mode 100644 index 000000000000..fd6a937fec6b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3e2557f206fd81d82a3b9d59113105040beb891f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3e562e6c3af28b8478020ce3c3bf73c036001c93.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3e562e6c3af28b8478020ce3c3bf73c036001c93.hip new file mode 100644 index 000000000000..cad897f01019 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3e562e6c3af28b8478020ce3c3bf73c036001c93.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3e61b019e1398a6a3c36143fb84b5ff22c9f4508.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3e61b019e1398a6a3c36143fb84b5ff22c9f4508.hip new file mode 100644 index 000000000000..c110e4b7f6d4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3e61b019e1398a6a3c36143fb84b5ff22c9f4508.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3e839660557dee9d5bcda9b56940ce23236c5f6d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3e839660557dee9d5bcda9b56940ce23236c5f6d.hip new file mode 100644 index 000000000000..ae68340237ab --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3e839660557dee9d5bcda9b56940ce23236c5f6d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3eb2ea922daabbba131b90713e06d8caf5f30662.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3eb2ea922daabbba131b90713e06d8caf5f30662.hip new file mode 100644 index 000000000000..c963f4cd6256 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3eb2ea922daabbba131b90713e06d8caf5f30662.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 128, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128, + ck_tile::fp16_t, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3ecf565a5a1c4a09887c67ac3b9a019dca427ac0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3ecf565a5a1c4a09887c67ac3b9a019dca427ac0.hip new file mode 100644 index 000000000000..875e8ce0f19c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3ecf565a5a1c4a09887c67ac3b9a019dca427ac0.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3f34433b784d1e405ade3378918641372a30bf6b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3f34433b784d1e405ade3378918641372a30bf6b.hip new file mode 100644 index 000000000000..fd0e82c01b22 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3f34433b784d1e405ade3378918641372a30bf6b.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3f5e01b4f2ca8ea10898c39d6570bd74e85f46ed.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3f5e01b4f2ca8ea10898c39d6570bd74e85f46ed.hip new file mode 100644 index 000000000000..1ed5d01a6d72 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3f5e01b4f2ca8ea10898c39d6570bd74e85f46ed.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3f7315955f555768f24585a50d75e216c40f062d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3f7315955f555768f24585a50d75e216c40f062d.hip new file mode 100644 index 000000000000..71bc0a9e7c0b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3f7315955f555768f24585a50d75e216c40f062d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3fad30ff0739ab5dede67a96e859f8c474c245f8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3fad30ff0739ab5dede67a96e859f8c474c245f8.hip new file mode 100644 index 000000000000..979c0bfdda3e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3fad30ff0739ab5dede67a96e859f8c474c245f8.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 256, + true, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, false>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3fcc6893456a559c7d22714116022fc69b372266.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3fcc6893456a559c7d22714116022fc69b372266.hip new file 
mode 100644 index 000000000000..c8e79b127c37 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_3fcc6893456a559c7d22714116022fc69b372266.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4018b1fcee808b6cccd131418b6ae9e8bf900d8f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4018b1fcee808b6cccd131418b6ae9e8bf900d8f.hip new file mode 100644 index 000000000000..003b14e205e0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4018b1fcee808b6cccd131418b6ae9e8bf900d8f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4018f690b6322588041bb467beabd8a7bc79a2e0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4018f690b6322588041bb467beabd8a7bc79a2e0.hip new file mode 100644 index 000000000000..e703b937148d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4018f690b6322588041bb467beabd8a7bc79a2e0.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 64, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<64, + ck_tile::bf16_t, + false, + true, + false, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_40357c5e9739eae136a7abf92bc38d3ac94753f8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_40357c5e9739eae136a7abf92bc38d3ac94753f8.hip new file mode 100644 index 000000000000..0c184f43ee74 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_40357c5e9739eae136a7abf92bc38d3ac94753f8.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4052ca6a3ec02f6559e4bbf1edde42ad2d127c26.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4052ca6a3ec02f6559e4bbf1edde42ad2d127c26.hip new file mode 100644 index 000000000000..f24e798a3328 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4052ca6a3ec02f6559e4bbf1edde42ad2d127c26.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_405e7efa263223148318ae96bd1929b382e994e1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_405e7efa263223148318ae96bd1929b382e994e1.hip new file mode 100644 index 000000000000..019ebafe5873 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_405e7efa263223148318ae96bd1929b382e994e1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_40aa64439b80ff8dd12498b3e5f6b625da16e285.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_40aa64439b80ff8dd12498b3e5f6b625da16e285.hip new file mode 100644 index 000000000000..2f53b38c1631 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_40aa64439b80ff8dd12498b3e5f6b625da16e285.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_40db688a9189e1c47c300d474df946a248a63303.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_40db688a9189e1c47c300d474df946a248a63303.hip new file mode 100644 index 000000000000..bd1baeb57f5c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_40db688a9189e1c47c300d474df946a248a63303.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4118e3ab290263ed2576feaf22a1944bf2ddcb7a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4118e3ab290263ed2576feaf22a1944bf2ddcb7a.hip new file mode 100644 index 000000000000..18b57f6f1082 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4118e3ab290263ed2576feaf22a1944bf2ddcb7a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_415b183c50dd2663dabe3eb8b780913b778c54ab.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_415b183c50dd2663dabe3eb8b780913b778c54ab.hip new file mode 100644 index 000000000000..b662b55b69b0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_415b183c50dd2663dabe3eb8b780913b778c54ab.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4160f6b6d0869740a5a411abd80108f729f810eb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4160f6b6d0869740a5a411abd80108f729f810eb.hip new file mode 100644 index 000000000000..1eb0f4fd3ff3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4160f6b6d0869740a5a411abd80108f729f810eb.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_417b1cb14b67dc82f614831550f7deb0895bd7e4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_417b1cb14b67dc82f614831550f7deb0895bd7e4.hip new file mode 100644 index 000000000000..83ef09fb1036 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_417b1cb14b67dc82f614831550f7deb0895bd7e4.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_419461cdb5687ebbb7bf0be136071d70420c1619.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_419461cdb5687ebbb7bf0be136071d70420c1619.hip new file mode 100644 index 000000000000..dd3ded85bb1d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_419461cdb5687ebbb7bf0be136071d70420c1619.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_41b68458076e6cb129d3ec793e95b91430a0c8a1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_41b68458076e6cb129d3ec793e95b91430a0c8a1.hip new file mode 100644 index 000000000000..306bfbcf2519 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_41b68458076e6cb129d3ec793e95b91430a0c8a1.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_41db3f29d1940e59dadc357c040ea37a6ff208d9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_41db3f29d1940e59dadc357c040ea37a6ff208d9.hip new file mode 100644 index 000000000000..3fac933775f3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_41db3f29d1940e59dadc357c040ea37a6ff208d9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4217a48a1677bd26cd48e512f1fc8830a8a551b8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4217a48a1677bd26cd48e512f1fc8830a8a551b8.hip new file mode 100644 index 000000000000..47c15f956e7f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4217a48a1677bd26cd48e512f1fc8830a8a551b8.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_428ce4e14cf94b284ffa735fe03d923cc74c9fe0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_428ce4e14cf94b284ffa735fe03d923cc74c9fe0.hip new file mode 100644 index 000000000000..a0d0d02b9997 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_428ce4e14cf94b284ffa735fe03d923cc74c9fe0.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_429b82a27571ac91e3631cbdb7e0a58155abf962.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_429b82a27571ac91e3631cbdb7e0a58155abf962.hip new file mode 100644 index 000000000000..b376dd808e60 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_429b82a27571ac91e3631cbdb7e0a58155abf962.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 64, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<64, + ck_tile::bf16_t, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_42e2326066c91452335eac05f25a6311376bd9e5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_42e2326066c91452335eac05f25a6311376bd9e5.hip new file mode 100644 index 000000000000..c479c02017d1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_42e2326066c91452335eac05f25a6311376bd9e5.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4306c6c37cf472ad262f53941611b5e60072bdf6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4306c6c37cf472ad262f53941611b5e60072bdf6.hip new file mode 100644 index 000000000000..98aac8b5ab7f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4306c6c37cf472ad262f53941611b5e60072bdf6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4347e039c003489dd528faf5d710e687321a3fd7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4347e039c003489dd528faf5d710e687321a3fd7.hip new file mode 100644 index 000000000000..b2f58522f169 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4347e039c003489dd528faf5d710e687321a3fd7.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4356b3a2ff49f72b91a6b9c215df285f2798ad47.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4356b3a2ff49f72b91a6b9c215df285f2798ad47.hip new file mode 100644 index 000000000000..d419d50e803a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4356b3a2ff49f72b91a6b9c215df285f2798ad47.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
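
Each of the generated backward translation units above follows the same pattern: one kernel configuration is fixed at compile time through a trait tag, the file provides explicit specializations of fmha_bwd_dq_dk_dv_<trait>, fmha_bwd_dq_dk_dv_oneshot_<trait>, and fmha_bwd_dq_dk_dv_get_name_<trait>, and the actual kernel launch is compiled only for gfx90a/gfx942 with a zero return as the off-target fallback. The sketch below is a minimal, self-contained analogue of that structure using stand-in names (bwd_trait, run_bwd, bwd_name, and bwd_args are illustrative only and are not part of ck_tile or of this patch); it shows how the per-configuration specializations, the name query used for logging, and the architecture guard fit together, and does not reproduce the real ck_tile launch.

    // Simplified analogue of the generated files: one trait tag per kernel
    // configuration, an explicit specialization per tag, and a name accessor.
    // All identifiers here are stand-ins for illustration, not ck_tile APIs.
    #include <iostream>
    #include <string>

    struct bwd_args {};                      // stand-in for fmha_bwd_args
    template <int HDim, bool Deterministic>  // stand-in for fmha_bwd_dq_dk_dv_traits_<...>
    struct bwd_trait {};

    template <typename Trait> float run_bwd(const bwd_args&);  // like fmha_bwd_dq_dk_dv_<trait>
    template <typename Trait> std::string bwd_name();          // like fmha_bwd_dq_dk_dv_get_name_<trait>

    // One "generated file" worth of specializations for a single configuration.
    template <> float run_bwd<bwd_trait<128, false>>(const bwd_args&)
    {
    #if defined(__gfx90a__) || defined(__gfx942__)
        // On supported GPUs the real code launches the ck_tile kernel here
        // and returns the measured time.
        return 1.0f;
    #else
        return 0.0f;  // same off-target fallback the generated files use
    #endif
    }
    template <> std::string bwd_name<bwd_trait<128, false>>()
    {
        return "bwd_hd128_nondeterministic";  // placeholder for k_::GetName()
    }

    int main()
    {
        bwd_args a;
        std::cout << bwd_name<bwd_trait<128, false>>() << ": "
                  << run_bwd<bwd_trait<128, false>>(a) << '\n';
    }

One plausible reason for emitting one configuration per translation unit is to keep each object file small and independently compilable, which is why the patch adds many near-identical .hip files that differ only in their trait parameters.
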
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4377ac04be3a6cbdbfbe57612a469412812fb5b5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4377ac04be3a6cbdbfbe57612a469412812fb5b5.hip new file mode 100644 index 000000000000..3e326ee37af8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4377ac04be3a6cbdbfbe57612a469412812fb5b5.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_438e3565f4c720e6c9691b0d33c1392936e2e7ae.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_438e3565f4c720e6c9691b0d33c1392936e2e7ae.hip new file mode 100644 index 000000000000..a3c0807ec0b3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_438e3565f4c720e6c9691b0d33c1392936e2e7ae.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4395d3c96b3f4556b9765fd0a3b5701b2fb10948.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4395d3c96b3f4556b9765fd0a3b5701b2fb10948.hip new file mode 100644 index 000000000000..1e96c426d210 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4395d3c96b3f4556b9765fd0a3b5701b2fb10948.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 128, + true, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128, + ck_tile::fp16_t, + true, + true, + false, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_43e7c78e8f65be35e2753a0ad5123118555c56b2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_43e7c78e8f65be35e2753a0ad5123118555c56b2.hip new file mode 100644 index 000000000000..2f1461f5712e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_43e7c78e8f65be35e2753a0ad5123118555c56b2.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_43f2156a04b18bab55af60e9357f28d8a4604e8e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_43f2156a04b18bab55af60e9357f28d8a4604e8e.hip new file mode 100644 index 000000000000..fcc0d7ac7bc6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_43f2156a04b18bab55af60e9357f28d8a4604e8e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4409f2a7deb027e864afdfc9975d3ab93c5dcc9a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4409f2a7deb027e864afdfc9975d3ab93c5dcc9a.hip new file mode 100644 index 000000000000..62c4fdb191b7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4409f2a7deb027e864afdfc9975d3ab93c5dcc9a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4432c5214c4d40c54ca2d02f0d4785c6d6902370.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4432c5214c4d40c54ca2d02f0d4785c6d6902370.hip new file mode 100644 index 000000000000..7aea6d7a4c24 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4432c5214c4d40c54ca2d02f0d4785c6d6902370.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_44462715ed5f192532760d6f4c66ff9d4e20e254.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_44462715ed5f192532760d6f4c66ff9d4e20e254.hip new file mode 100644 index 000000000000..6f48793bb268 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_44462715ed5f192532760d6f4c66ff9d4e20e254.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_44564dddf8b492d80be54854abb8d1d831e42679.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_44564dddf8b492d80be54854abb8d1d831e42679.hip new file mode 100644 index 000000000000..ab9446438c8e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_44564dddf8b492d80be54854abb8d1d831e42679.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_445cd8fa559588f4264ce6192f2de3e3065365ea.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_445cd8fa559588f4264ce6192f2de3e3065365ea.hip new file mode 100644 index 000000000000..1968b601b683 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_445cd8fa559588f4264ce6192f2de3e3065365ea.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_445e28a8a51cd435130ded2abc9fc606e522c713.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_445e28a8a51cd435130ded2abc9fc606e522c713.hip new file mode 100644 index 000000000000..dafbd57a8749 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_445e28a8a51cd435130ded2abc9fc606e522c713.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4462b192a64efb60d5484798526278ac7a0fb9fa.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4462b192a64efb60d5484798526278ac7a0fb9fa.hip new file mode 100644 index 000000000000..70e353ec443c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4462b192a64efb60d5484798526278ac7a0fb9fa.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4466b6c6b2ec3acb40ac1cda432efa1e4e62d9d9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4466b6c6b2ec3acb40ac1cda432efa1e4e62d9d9.hip new file mode 100644 index 000000000000..62f5cfe22528 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4466b6c6b2ec3acb40ac1cda432efa1e4e62d9d9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_44690e48f30657b0fcfa26fb3b9af3ef76e792e3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_44690e48f30657b0fcfa26fb3b9af3ef76e792e3.hip new file mode 100644 index 000000000000..b6b4283da67a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_44690e48f30657b0fcfa26fb3b9af3ef76e792e3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_44c181996532676f2140fd026707135144e9d37b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_44c181996532676f2140fd026707135144e9d37b.hip new file mode 100644 index 000000000000..8f05e41b7406 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_44c181996532676f2140fd026707135144e9d37b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_44cc95831c347212021c0bab7b43acd7daabce42.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_44cc95831c347212021c0bab7b43acd7daabce42.hip new file mode 100644 index 000000000000..a14760b5823c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_44cc95831c347212021c0bab7b43acd7daabce42.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_44d82b58fdc3e5b7a7c20490ce7f5acce4e6ec79.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_44d82b58fdc3e5b7a7c20490ce7f5acce4e6ec79.hip new file mode 100644 index 000000000000..321626f2d284 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_44d82b58fdc3e5b7a7c20490ce7f5acce4e6ec79.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_451fbbdc2dcf2ec81efce34673ee6c425cc16ca2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_451fbbdc2dcf2ec81efce34673ee6c425cc16ca2.hip new file mode 100644 index 000000000000..cd9499e0354c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_451fbbdc2dcf2ec81efce34673ee6c425cc16ca2.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4568af1b2f104664fd05d21ad789aed39ecfa42b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4568af1b2f104664fd05d21ad789aed39ecfa42b.hip new file mode 100644 index 000000000000..5ff6782c43c8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4568af1b2f104664fd05d21ad789aed39ecfa42b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_457eaffbff3c58183a656687010daa2c16cfc26e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_457eaffbff3c58183a656687010daa2c16cfc26e.hip new file mode 100644 index 000000000000..3c2f785595f7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_457eaffbff3c58183a656687010daa2c16cfc26e.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 64, + 256, + true, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<256, + ck_tile::fp16_t, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_458d708d13577f2b92e6d5adfe952a87e0cf7be5.hip 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_458d708d13577f2b92e6d5adfe952a87e0cf7be5.hip new file mode 100644 index 000000000000..978f29329e4e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_458d708d13577f2b92e6d5adfe952a87e0cf7be5.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + 
+template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_459c8fb6028991321b09a990c2188d854d940268.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_459c8fb6028991321b09a990c2188d854d940268.hip new file mode 100644 index 000000000000..a7206f643ddf --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_459c8fb6028991321b09a990c2188d854d940268.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_459ea3713aef9b916e1b38a882a45012930924d3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_459ea3713aef9b916e1b38a882a45012930924d3.hip new file mode 100644 index 000000000000..ade3464bbb2d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_459ea3713aef9b916e1b38a882a45012930924d3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_45b9871c220c0065d74bffeed4021d0304a9625c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_45b9871c220c0065d74bffeed4021d0304a9625c.hip new file mode 100644 index 000000000000..3cd9c78e3e9f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_45b9871c220c0065d74bffeed4021d0304a9625c.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_45f4363f50af1e7ccd24751d5f5b181bf32c604f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_45f4363f50af1e7ccd24751d5f5b181bf32c604f.hip new file mode 100644 index 000000000000..ff6b89fae8be --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_45f4363f50af1e7ccd24751d5f5b181bf32c604f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4601680af41c8738089ff377147e0547dcad114d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4601680af41c8738089ff377147e0547dcad114d.hip new file mode 100644 index 000000000000..9f6fd3567890 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4601680af41c8738089ff377147e0547dcad114d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_461737a13e24009bf1a5a4b780175043a9f2e33e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_461737a13e24009bf1a5a4b780175043a9f2e33e.hip new file mode 100644 index 000000000000..33bfc80603c6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_461737a13e24009bf1a5a4b780175043a9f2e33e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = 
ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4666db0ff7b035e54f2c0e59acedc2131b722a55.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4666db0ff7b035e54f2c0e59acedc2131b722a55.hip new file mode 100644 index 000000000000..9d05594cf218 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4666db0ff7b035e54f2c0e59acedc2131b722a55.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_468a5f057fd5cef2df5f919f5102f47e86901e3b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_468a5f057fd5cef2df5f919f5102f47e86901e3b.hip new file mode 100644 index 000000000000..a78256f15483 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_468a5f057fd5cef2df5f919f5102f47e86901e3b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( 
+ s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_474fe2d739eca8c93fdcb2c105d4154cee6ca1c1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_474fe2d739eca8c93fdcb2c105d4154cee6ca1c1.hip new file mode 100644 index 000000000000..220a5a4f2ce4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_474fe2d739eca8c93fdcb2c105d4154cee6ca1c1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = 
fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_47548aa042c69bb9c59a8bf706b44028aaa41830.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_47548aa042c69bb9c59a8bf706b44028aaa41830.hip new file mode 100644 index 000000000000..309681f33ce3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_47548aa042c69bb9c59a8bf706b44028aaa41830.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_47f3ced9b5ddb0dfee8ed5e7df8eca0bbe273047.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_47f3ced9b5ddb0dfee8ed5e7df8eca0bbe273047.hip new file mode 100644 index 000000000000..8b2a1e2a3b0c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_47f3ced9b5ddb0dfee8ed5e7df8eca0bbe273047.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_47fe73f04cef91cd2a0682e905483968ff80eadb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_47fe73f04cef91cd2a0682e905483968ff80eadb.hip new file mode 100644 index 000000000000..ae19c51db82e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_47fe73f04cef91cd2a0682e905483968ff80eadb.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + 
return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_481415463f0316ebe25ff2fda47c68cc54db3359.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_481415463f0316ebe25ff2fda47c68cc54db3359.hip new file mode 100644 index 000000000000..c7c7457b8006 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_481415463f0316ebe25ff2fda47c68cc54db3359.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4824e1f8cda50f80988857611da766685da94494.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4824e1f8cda50f80988857611da766685da94494.hip new file mode 100644 index 000000000000..b17bad6b12e5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4824e1f8cda50f80988857611da766685da94494.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_48280c91d7cd8712fd533e246a6b0f758834abc9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_48280c91d7cd8712fd533e246a6b0f758834abc9.hip new file mode 100644 index 000000000000..c94863ece048 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_48280c91d7cd8712fd533e246a6b0f758834abc9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_482e34930d11ff493007b1613993e01acc1af78d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_482e34930d11ff493007b1613993e01acc1af78d.hip new file mode 100644 index 000000000000..56917606f085 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_482e34930d11ff493007b1613993e01acc1af78d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = 
ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_48300e0aeabe337785d4c7b41796ce65df6cc42a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_48300e0aeabe337785d4c7b41796ce65df6cc42a.hip new file mode 100644 index 000000000000..a23d30926d57 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_48300e0aeabe337785d4c7b41796ce65df6cc42a.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+
+// auto generated by generate.py
+#include
+
+using fmha_dtype_0 = ck_tile::bf16_t;
+
+using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>;
+using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>;
+
+using fmha_shape_0 = ck_tile::TileFmhaShape,
+    fmha_warp_tile_0,
+    ck_tile::sequence<4, 1, 1>,
+    fmha_warp_tile_0,
+    true>;
+
+using fmha_trait_0 = ck_tile::TileFmhaTraits;
+using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask;
+
+using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem<
+    typename FmhaFwdTypeConfig::QDataType,
+    typename FmhaFwdTypeConfig::KDataType,
+    typename FmhaFwdTypeConfig::VDataType,
+    typename FmhaFwdTypeConfig::SaccDataType,
+    typename FmhaFwdTypeConfig::SMPLComputeDataType,
+    typename FmhaFwdTypeConfig::BiasDataType,
+    typename FmhaFwdTypeConfig::RandValOutputDataType,
+    typename FmhaFwdTypeConfig::LSEDataType,
+    typename FmhaFwdTypeConfig::PDataType,
+    typename FmhaFwdTypeConfig::OaccDataType,
+    typename FmhaFwdTypeConfig::ODataType,
+    fmha_shape_0,
+    false,
+    fmha_mask_0,
+    fmha_trait_0>;
+
+using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS<
+    fmha_pipeline_problem_0>;
+
+using fmha_epilogue_0 =
+    ck_tile::Default2DEpilogue::OaccDataType,
+        typename FmhaFwdTypeConfig::ODataType,
+        false, false>>;
+
+using fmha_kernel_0 =
+    ck_tile::FmhaFwdKernel,
+        fmha_pipeline_0,
+        fmha_epilogue_0>;
+
+using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true,
+    ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false, false>;
+
+#include
+
+template<>
+float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a)
+{
+    using k_ = fmha_kernel_0;
+    if(s.log_level_ > 0)
+        std::cout << ", " << k_::GetName() << std::flush;
+    auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a);
+    constexpr dim3 blocks = k_::BlockSize();
+    constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
+#if (defined(__gfx90a__) || defined(__gfx942__))
+    return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs));
+#else
+    return 0.0;
+#endif
+}
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_483eaea4096c8f5bee16a64860432f0634a253d8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_483eaea4096c8f5bee16a64860432f0634a253d8.hip
new file mode 100644
index 000000000000..a29e1beb4fc0
--- /dev/null
+++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_483eaea4096c8f5bee16a64860432f0634a253d8.hip
@@ -0,0 +1,84 @@
+// ==========================================
+// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
+// @generated
+// ==========================================
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_48435e5dd23e49e19dd313f9891ffec800ce74c2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_48435e5dd23e49e19dd313f9891ffec800ce74c2.hip new file mode 100644 index 000000000000..0301c13f7dd1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_48435e5dd23e49e19dd313f9891ffec800ce74c2.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_486f6c7c7655c34b7b9973ff357b0813f0a3fd7c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_486f6c7c7655c34b7b9973ff357b0813f0a3fd7c.hip new file mode 100644 index 000000000000..2a189a65a945 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_486f6c7c7655c34b7b9973ff357b0813f0a3fd7c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_487724686efd35731e5335efa949486c93ae26e3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_487724686efd35731e5335efa949486c93ae26e3.hip new file mode 100644 index 000000000000..ad8f080b4955 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_487724686efd35731e5335efa949486c93ae26e3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_489e7be0f85656d012a6451b65f6c1d2613b187d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_489e7be0f85656d012a6451b65f6c1d2613b187d.hip new file mode 100644 index 000000000000..9cbb835abd16 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_489e7be0f85656d012a6451b65f6c1d2613b187d.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + 
return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_48ae3af78583258c4b13c11a442022e0e058bb85.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_48ae3af78583258c4b13c11a442022e0e058bb85.hip new file mode 100644 index 000000000000..0eff4b86984a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_48ae3af78583258c4b13c11a442022e0e058bb85.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_48d7d145f96aa8958a9208d0c8887742a8c834fd.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_48d7d145f96aa8958a9208d0c8887742a8c834fd.hip new file mode 100644 index 000000000000..f4bb78b05111 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_48d7d145f96aa8958a9208d0c8887742a8c834fd.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, 
blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_48e9e858abf6f77489f3fadc4ee81edacd26705a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_48e9e858abf6f77489f3fadc4ee81edacd26705a.hip new file mode 100644 index 000000000000..3cb597dcfcd4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_48e9e858abf6f77489f3fadc4ee81edacd26705a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if 
(defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4904c5910a2d0595b39a3f87652a9d1ef4fcbe80.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4904c5910a2d0595b39a3f87652a9d1ef4fcbe80.hip new file mode 100644 index 000000000000..e023c01ef005 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4904c5910a2d0595b39a3f87652a9d1ef4fcbe80.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using 
fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_490a68220a7b621ae9817d7b77f55de239b0a4f3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_490a68220a7b621ae9817d7b77f55de239b0a4f3.hip new file mode 100644 index 000000000000..006da14cc912 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_490a68220a7b621ae9817d7b77f55de239b0a4f3.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 128, + true, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128, + ck_tile::fp16_t, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4911bdd71351610d55916d452495e599960d0a41.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4911bdd71351610d55916d452495e599960d0a41.hip new file mode 100644 index 000000000000..83d2fd0dc854 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4911bdd71351610d55916d452495e599960d0a41.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_492fbc418e829f89bcb8d93f8afd2869dd8dfccc.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_492fbc418e829f89bcb8d93f8afd2869dd8dfccc.hip new file mode 100644 index 000000000000..f3e0cbb7a064 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_492fbc418e829f89bcb8d93f8afd2869dd8dfccc.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 128, + true, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128, + ck_tile::bf16_t, + true, + true, + false, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_49d4c005d723cdab9fbc307933c1257d114b539e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_49d4c005d723cdab9fbc307933c1257d114b539e.hip new file mode 100644 index 000000000000..d34f2a8e97d8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_49d4c005d723cdab9fbc307933c1257d114b539e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_49f5017cc0f5c8c8dc71492e7765cf729c1f225c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_49f5017cc0f5c8c8dc71492e7765cf729c1f225c.hip new file mode 100644 index 000000000000..cfac59538b5b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_49f5017cc0f5c8c8dc71492e7765cf729c1f225c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4a06b5b153ea6e8b1e20d9aad9d4633333fd98f5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4a06b5b153ea6e8b1e20d9aad9d4633333fd98f5.hip new file mode 100644 index 000000000000..19df50232d28 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4a06b5b153ea6e8b1e20d9aad9d4633333fd98f5.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4a2e6b05e7e4de2cb23d815f8b2c8adf22131c0c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4a2e6b05e7e4de2cb23d815f8b2c8adf22131c0c.hip new file mode 100644 index 000000000000..7c7e7201db31 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4a2e6b05e7e4de2cb23d815f8b2c8adf22131c0c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4a4a00bd6ea27ff20a2903d619e1361b5e27672a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4a4a00bd6ea27ff20a2903d619e1361b5e27672a.hip new file mode 100644 index 000000000000..ddd06638883e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4a4a00bd6ea27ff20a2903d619e1361b5e27672a.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4a5dbf601de5754c03a03a1a42395dc0766fb8ac.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4a5dbf601de5754c03a03a1a42395dc0766fb8ac.hip new file mode 100644 index 000000000000..694fa722d3fa --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4a5dbf601de5754c03a03a1a42395dc0766fb8ac.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4a9f3da698a6103caf25d785928dd9f814ac27b4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4a9f3da698a6103caf25d785928dd9f814ac27b4.hip new file mode 100644 index 000000000000..30c20f805934 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4a9f3da698a6103caf25d785928dd9f814ac27b4.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4ab5d6e8fbfd92e9f7e47bda5cfbb0d4162a6319.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4ab5d6e8fbfd92e9f7e47bda5cfbb0d4162a6319.hip new file mode 100644 index 000000000000..74f703b455bd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4ab5d6e8fbfd92e9f7e47bda5cfbb0d4162a6319.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); 
+ constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4afd02981f92fbef6277c1985cc479c12bae9239.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4afd02981f92fbef6277c1985cc479c12bae9239.hip new file mode 100644 index 000000000000..8baadf075053 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4afd02981f92fbef6277c1985cc479c12bae9239.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto 
[kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b1eaca3c37a82d19f8dc91f06764170069ca3af.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b1eaca3c37a82d19f8dc91f06764170069ca3af.hip new file mode 100644 index 000000000000..a75f6965bdfa --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b1eaca3c37a82d19f8dc91f06764170069ca3af.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using 
fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b2e7f96b095ebfb66ecc7a75752fba2a63e4f37.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b2e7f96b095ebfb66ecc7a75752fba2a63e4f37.hip new file mode 100644 index 000000000000..948e9dea9c70 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b2e7f96b095ebfb66ecc7a75752fba2a63e4f37.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b30f472f00bec9da0564ddc40e07112b5f9a117.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b30f472f00bec9da0564ddc40e07112b5f9a117.hip new file mode 100644 index 000000000000..0bc5de1e5a63 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b30f472f00bec9da0564ddc40e07112b5f9a117.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b45948f2795293e72530b02669c4f549608ea7f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b45948f2795293e72530b02669c4f549608ea7f.hip new file mode 100644 index 000000000000..329cee50fdde --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b45948f2795293e72530b02669c4f549608ea7f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b4c03c916393d6be7c5181369ebcef949eaa763.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b4c03c916393d6be7c5181369ebcef949eaa763.hip new file mode 100644 index 000000000000..29af28afdda9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b4c03c916393d6be7c5181369ebcef949eaa763.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b68e4d00295b294320b94bc777d7d34609127e0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b68e4d00295b294320b94bc777d7d34609127e0.hip new file mode 100644 index 000000000000..5ca0f423c56c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b68e4d00295b294320b94bc777d7d34609127e0.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 256, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b7393d55600c9892558248f4131fc06a6cf3309.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b7393d55600c9892558248f4131fc06a6cf3309.hip new file mode 100644 index 000000000000..65822931820c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b7393d55600c9892558248f4131fc06a6cf3309.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b74439f42140cdda9bb0f78d995d741212a35f4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b74439f42140cdda9bb0f78d995d741212a35f4.hip new file mode 100644 index 000000000000..487dce37eebc --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b74439f42140cdda9bb0f78d995d741212a35f4.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b76e5dce9af523422782dd25d8dcf6f25edc68f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b76e5dce9af523422782dd25d8dcf6f25edc68f.hip new file mode 100644 index 000000000000..d373ae30a1c0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4b76e5dce9af523422782dd25d8dcf6f25edc68f.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4baf664bfdf070362bcc91af77d1bc406f744351.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4baf664bfdf070362bcc91af77d1bc406f744351.hip new file mode 100644 index 000000000000..3a7fcef5e0ee --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4baf664bfdf070362bcc91af77d1bc406f744351.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4bc48576f285325345fa1205e5e7e01787b74f71.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4bc48576f285325345fa1205e5e7e01787b74f71.hip new file mode 100644 index 000000000000..3a7a5e5295c6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4bc48576f285325345fa1205e5e7e01787b74f71.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4bd4d46397a3749646b232b306688e52b8c6e584.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4bd4d46397a3749646b232b306688e52b8c6e584.hip new file mode 100644 index 000000000000..8ba7618b58c8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4bd4d46397a3749646b232b306688e52b8c6e584.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = 
ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4be4a98f150f3f9ab6f03b5fd0968c5454565c9a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4be4a98f150f3f9ab6f03b5fd0968c5454565c9a.hip new file mode 100644 index 000000000000..4ee9df9540eb --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4be4a98f150f3f9ab6f03b5fd0968c5454565c9a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4beca56234ff6fb4f23b9b24822887fd9a3d0df9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4beca56234ff6fb4f23b9b24822887fd9a3d0df9.hip new file mode 100644 index 000000000000..0a39028591c1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4beca56234ff6fb4f23b9b24822887fd9a3d0df9.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4bef4d120e71bfcfe61d67aa44d24ceb907c2b9e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4bef4d120e71bfcfe61d67aa44d24ceb907c2b9e.hip new file mode 100644 index 000000000000..10933b2ba7c9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4bef4d120e71bfcfe61d67aa44d24ceb907c2b9e.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 256, + true, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, true>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4c0c50a1fac82d47dff2357ee3ddbfa0b2c8d487.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4c0c50a1fac82d47dff2357ee3ddbfa0b2c8d487.hip new file mode 100644 index 000000000000..7536bd3bba03 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4c0c50a1fac82d47dff2357ee3ddbfa0b2c8d487.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 128, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128, + ck_tile::bf16_t, + false, + true, + false, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4c69d06e3f32e3b6d28d3e54ad764b472741c193.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4c69d06e3f32e3b6d28d3e54ad764b472741c193.hip new file mode 100644 index 000000000000..c2cbb592946a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4c69d06e3f32e3b6d28d3e54ad764b472741c193.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4c8720923c3452e3aebd7b9c1b4b23f0c35d7e4f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4c8720923c3452e3aebd7b9c1b4b23f0c35d7e4f.hip new file mode 100644 index 000000000000..2c5ad3e47e09 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4c8720923c3452e3aebd7b9c1b4b23f0c35d7e4f.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4cabdafad0bf803223ba5e8f474cd59233dc48cb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4cabdafad0bf803223ba5e8f474cd59233dc48cb.hip new file mode 100644 index 000000000000..ae7425e5dbf5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4cabdafad0bf803223ba5e8f474cd59233dc48cb.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 128, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128, + ck_tile::fp16_t, + false, + false, + true, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4cb1861e31df98bdfd731efc3d335055090d83af.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4cb1861e31df98bdfd731efc3d335055090d83af.hip new file mode 100644 index 000000000000..cc3c001e9e9b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4cb1861e31df98bdfd731efc3d335055090d83af.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4cd3de43cc1f7588d62a10362f59d113ee818846.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4cd3de43cc1f7588d62a10362f59d113ee818846.hip new file mode 100644 index 000000000000..78cb9146355a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4cd3de43cc1f7588d62a10362f59d113ee818846.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4ce03571f1d2779bdeaf0a6a2d617e236d191c11.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4ce03571f1d2779bdeaf0a6a2d617e236d191c11.hip new file mode 100644 index 000000000000..207b155c5944 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4ce03571f1d2779bdeaf0a6a2d617e236d191c11.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4ce671f5defd76ca08614a7a1f184c36c0f1e2ab.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4ce671f5defd76ca08614a7a1f184c36c0f1e2ab.hip new file mode 100644 index 000000000000..9bc545c3cd40 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4ce671f5defd76ca08614a7a1f184c36c0f1e2ab.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4d3b1ae63e127b6e6afe39e354d4995afc5faeaf.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4d3b1ae63e127b6e6afe39e354d4995afc5faeaf.hip new file mode 100644 index 000000000000..73e33032d349 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4d3b1ae63e127b6e6afe39e354d4995afc5faeaf.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS 
AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, 
grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4d5f3cf0f78f73df79665c26b20b0805615e1b04.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4d5f3cf0f78f73df79665c26b20b0805615e1b04.hip new file mode 100644 index 000000000000..ab6f84bf934b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4d5f3cf0f78f73df79665c26b20b0805615e1b04.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4d65e58c9f147498ed04dd51fe1393770603a6d3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4d65e58c9f147498ed04dd51fe1393770603a6d3.hip new file mode 100644 index 000000000000..107834d38364 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4d65e58c9f147498ed04dd51fe1393770603a6d3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = 
ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4d7dc0f356b630179916f8fc2041b7f1402b46df.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4d7dc0f356b630179916f8fc2041b7f1402b46df.hip new file mode 100644 index 000000000000..0f4093125b82 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4d7dc0f356b630179916f8fc2041b7f1402b46df.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4da9e9b7277bc90518ab92860bef2097ba96d982.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4da9e9b7277bc90518ab92860bef2097ba96d982.hip new file mode 100644 index 000000000000..965cd600b8a5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4da9e9b7277bc90518ab92860bef2097ba96d982.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4db2e63cfebcf84043f79be0321708cd159c62b9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4db2e63cfebcf84043f79be0321708cd159c62b9.hip new file mode 100644 index 000000000000..e0d9f465a9c8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4db2e63cfebcf84043f79be0321708cd159c62b9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4dbdd9c3f496a27bde68cf86374999ff2dd53505.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4dbdd9c3f496a27bde68cf86374999ff2dd53505.hip new file mode 100644 index 000000000000..280f3877a199 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4dbdd9c3f496a27bde68cf86374999ff2dd53505.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4dc87b7d385e7b092e4706c464217b004fd8a6a4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4dc87b7d385e7b092e4706c464217b004fd8a6a4.hip new file mode 100644 index 000000000000..e03645b96d01 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4dc87b7d385e7b092e4706c464217b004fd8a6a4.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4dde56efe17f4fd36a11cc959320a5e43f1dc232.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4dde56efe17f4fd36a11cc959320a5e43f1dc232.hip new file mode 100644 index 000000000000..8f794db20e2f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4dde56efe17f4fd36a11cc959320a5e43f1dc232.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4e0a88ccef04e81b8c684b695f7cb4310e448915.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4e0a88ccef04e81b8c684b695f7cb4310e448915.hip new file mode 100644 index 000000000000..090f8fb619d0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4e0a88ccef04e81b8c684b695f7cb4310e448915.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = 
ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4e15e4f16de26068cba30ef12fc29332d45e460e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4e15e4f16de26068cba30ef12fc29332d45e460e.hip new file mode 100644 index 000000000000..869bf7f26273 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4e15e4f16de26068cba30ef12fc29332d45e460e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4e47f8fa40332c6ed12d9971e0b539049a871c34.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4e47f8fa40332c6ed12d9971e0b539049a871c34.hip new file mode 100644 index 000000000000..f09735bdb22c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4e47f8fa40332c6ed12d9971e0b539049a871c34.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4e760de14b71a41882ec4a2c7362565af36d1a5d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4e760de14b71a41882ec4a2c7362565af36d1a5d.hip new file mode 100644 index 000000000000..b3cf01ae4782 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4e760de14b71a41882ec4a2c7362565af36d1a5d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4e79dce18e49ffe024fe4cd0693ad3399f5edaee.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4e79dce18e49ffe024fe4cd0693ad3399f5edaee.hip new file mode 100644 index 000000000000..a3e5a78f9709 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4e79dce18e49ffe024fe4cd0693ad3399f5edaee.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4e9a933b916285d9580a76df543cfafc88a536cb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4e9a933b916285d9580a76df543cfafc88a536cb.hip new file mode 100644 index 000000000000..7c4370995ee3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4e9a933b916285d9580a76df543cfafc88a536cb.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = 
ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4ec2075f394acfb14fae7b1ef4304fd9b654ba0d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4ec2075f394acfb14fae7b1ef4304fd9b654ba0d.hip new file mode 100644 index 000000000000..e90dcb85f350 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4ec2075f394acfb14fae7b1ef4304fd9b654ba0d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4ed6da5357b67cc28aee4afa9523adaf055c4e32.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4ed6da5357b67cc28aee4afa9523adaf055c4e32.hip new file mode 100644 index 000000000000..17bcce9dfb4a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4ed6da5357b67cc28aee4afa9523adaf055c4e32.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4ef35d82ceb4af2e07719c16109c6d72eaedce67.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4ef35d82ceb4af2e07719c16109c6d72eaedce67.hip new file mode 100644 index 000000000000..a8c94b4bffdf --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4ef35d82ceb4af2e07719c16109c6d72eaedce67.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4f0aded9d1baec3125ce8e176248cb146ca580fa.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4f0aded9d1baec3125ce8e176248cb146ca580fa.hip new file mode 100644 index 000000000000..b28531693aba --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4f0aded9d1baec3125ce8e176248cb146ca580fa.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4f1e1c969b57659e7e1367ac9ba10ed5ef5b69a9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4f1e1c969b57659e7e1367ac9ba10ed5ef5b69a9.hip new file mode 100644 index 000000000000..ea856e55da7a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4f1e1c969b57659e7e1367ac9ba10ed5ef5b69a9.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 128, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128, + ck_tile::fp16_t, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4f44435491aa68acb3217b0e693232c67641a2db.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4f44435491aa68acb3217b0e693232c67641a2db.hip new file mode 100644 index 000000000000..c1738ae45e55 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4f44435491aa68acb3217b0e693232c67641a2db.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4f4a5d56721bb1a1332a65882132a8c5763932ec.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4f4a5d56721bb1a1332a65882132a8c5763932ec.hip new file mode 100644 index 000000000000..3ba1ce44f266 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4f4a5d56721bb1a1332a65882132a8c5763932ec.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4f6243c6850c0a2d2b7bf1476e12f95f187257b6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4f6243c6850c0a2d2b7bf1476e12f95f187257b6.hip new file mode 100644 index 000000000000..2a28bfc26dec --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4f6243c6850c0a2d2b7bf1476e12f95f187257b6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4fa4d21931b9afcbd70b1567995d3eeb6f9308aa.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4fa4d21931b9afcbd70b1567995d3eeb6f9308aa.hip new file mode 100644 index 000000000000..cd233c9b3bba --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4fa4d21931b9afcbd70b1567995d3eeb6f9308aa.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4fa883a36a76edb276a66c5d779294f170d6d4b7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4fa883a36a76edb276a66c5d779294f170d6d4b7.hip new file mode 100644 index 000000000000..09ffdabfe312 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4fa883a36a76edb276a66c5d779294f170d6d4b7.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4fd34faa8b168e2ac7862641229e6146d3e28aee.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4fd34faa8b168e2ac7862641229e6146d3e28aee.hip new file mode 100644 index 000000000000..e633af5ce3e9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4fd34faa8b168e2ac7862641229e6146d3e28aee.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4fe530cbf6363a8f08a94728e45e88ecde299e7b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4fe530cbf6363a8f08a94728e45e88ecde299e7b.hip new file mode 100644 index 000000000000..4868b21c49fd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4fe530cbf6363a8f08a94728e45e88ecde299e7b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4ff20bafbf156fe8fb80bdd84a5d2f3a4a944c1a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4ff20bafbf156fe8fb80bdd84a5d2f3a4a944c1a.hip new file mode 100644 index 000000000000..fc5da06f3e1f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_4ff20bafbf156fe8fb80bdd84a5d2f3a4a944c1a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_501dcf3213efd214cc2ce8c9ba0027f991d241b4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_501dcf3213efd214cc2ce8c9ba0027f991d241b4.hip new file mode 100644 index 000000000000..ba839ff1f706 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_501dcf3213efd214cc2ce8c9ba0027f991d241b4.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5052b2318dbb78b1a82ef03666a35a623f44481b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5052b2318dbb78b1a82ef03666a35a623f44481b.hip new file mode 100644 index 000000000000..7e68e05bdd1d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5052b2318dbb78b1a82ef03666a35a623f44481b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
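Editor's note (not part of the patch): each autogenerated .hip file above follows the same pattern — one compile-time configuration (head dim, dtype, mask/bias/dropout flags) is baked into a trait type, and a single function-template specialization for that trait launches the matching ck_tile kernel. The standalone sketch below only illustrates that trait-keyed dispatch idea under assumed placeholder names (my_traits, run_kernel_, dispatch); none of these names belong to the ck_tile or PyTorch API.

// Hypothetical, simplified sketch of trait-keyed kernel dispatch.
// All names are placeholders for illustration only.
#include <iostream>

// Each generated translation unit corresponds to one compile-time configuration.
template <int HeadDim, bool IsCausal>
struct my_traits {};

// Primary template is declared but never defined; only generated
// specializations exist, so an unsupported configuration fails at link time.
template <typename Traits>
float run_kernel_();

// "Generated" specializations, one per configuration (analogous to the
// fmha_fwd_<trait_0> / fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0> definitions above).
template <>
float run_kernel_<my_traits<32, true>>() { return 1.0f; }

template <>
float run_kernel_<my_traits<64, false>>() { return 2.0f; }

// A hand-written dispatcher maps runtime parameters onto the matching
// compile-time specialization.
float dispatch(int head_dim, bool is_causal) {
    if (head_dim == 32 && is_causal)  return run_kernel_<my_traits<32, true>>();
    if (head_dim == 64 && !is_causal) return run_kernel_<my_traits<64, false>>();
    return -1.0f; // configuration not generated
}

int main() {
    std::cout << dispatch(32, true) << "\n";  // prints 1
    std::cout << dispatch(64, false) << "\n"; // prints 2
    return 0;
}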
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5093976cb7b32a8bd28ce92fc13af00a3e21f737.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5093976cb7b32a8bd28ce92fc13af00a3e21f737.hip new file mode 100644 index 000000000000..dcd91d230a59 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5093976cb7b32a8bd28ce92fc13af00a3e21f737.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_50e59bd079f4d205b613056f975fd2b4e372ab10.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_50e59bd079f4d205b613056f975fd2b4e372ab10.hip new file mode 100644 index 000000000000..40d7912b5480 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_50e59bd079f4d205b613056f975fd2b4e372ab10.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_50e7b11019fc2299d70869253877319b03388244.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_50e7b11019fc2299d70869253877319b03388244.hip new file mode 100644 index 000000000000..60989e7b50ec --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_50e7b11019fc2299d70869253877319b03388244.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_50f887556a3540609649744957651ca667b91774.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_50f887556a3540609649744957651ca667b91774.hip new file mode 100644 index 000000000000..f984328259b2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_50f887556a3540609649744957651ca667b91774.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 256, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_50f915b4d9bd18a3c25a85917392ea4a5e88b349.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_50f915b4d9bd18a3c25a85917392ea4a5e88b349.hip new file mode 100644 index 000000000000..af3ba52165ad --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_50f915b4d9bd18a3c25a85917392ea4a5e88b349.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
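Editor's note (not part of the patch): every launcher body in these files is wrapped in the same preprocessor guard, so real device launches are only compiled for gfx90a or gfx942, while other targets fall back to a stub that returns 0.0. The snippet below is a minimal, hypothetical illustration of that guard pattern; launch_my_kernel is a made-up name, and the real files call ck_tile::launch_kernel instead.

// Minimal sketch of the per-architecture guard used in the launchers above.
#include <cstdio>

float launch_my_kernel() {
#if (defined(__gfx90a__) || defined(__gfx942__))
    // Compiling for a supported AMD GPU target: the real kernel launch would
    // happen here, returning the measured (or zero) elapsed time.
    return 0.5f;
#else
    // Any other target: keep the translation unit compilable and report 0.0,
    // matching the "return 0.0;" fallback in the generated launchers.
    return 0.0f;
#endif
}

int main() {
    std::printf("%f\n", launch_my_kernel());
    return 0;
}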
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_515128c6978449b33ce0c35b02a9e9aaad65ef7a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_515128c6978449b33ce0c35b02a9e9aaad65ef7a.hip new file mode 100644 index 000000000000..c7c98a48ed2f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_515128c6978449b33ce0c35b02a9e9aaad65ef7a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_522a2a9435103ed405dc1500d31652f1d431a49d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_522a2a9435103ed405dc1500d31652f1d431a49d.hip new file mode 100644 index 000000000000..dec8c583648d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_522a2a9435103ed405dc1500d31652f1d431a49d.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + 
return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_523e5bf45ec5008aa3aba4773e68a78e122b2fe7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_523e5bf45ec5008aa3aba4773e68a78e122b2fe7.hip new file mode 100644 index 000000000000..a940c486f419 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_523e5bf45ec5008aa3aba4773e68a78e122b2fe7.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_52688999141a72e61322140db29043ef9f7fbc3d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_52688999141a72e61322140db29043ef9f7fbc3d.hip new file mode 100644 index 000000000000..ab8b6d3cad2f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_52688999141a72e61322140db29043ef9f7fbc3d.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_526c89b7a04758b4badbf9695b316f877b8bb053.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_526c89b7a04758b4badbf9695b316f877b8bb053.hip new file mode 100644 index 000000000000..d64de6756473 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_526c89b7a04758b4badbf9695b316f877b8bb053.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_528db08068589c6e4c096054d26a2e5be63285b6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_528db08068589c6e4c096054d26a2e5be63285b6.hip new file mode 100644 index 000000000000..2cc28aa9544e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_528db08068589c6e4c096054d26a2e5be63285b6.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_52a89981a05963efcea7ba5c1e967638beeebbbb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_52a89981a05963efcea7ba5c1e967638beeebbbb.hip new file mode 100644 index 000000000000..d74acfbc2d0f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_52a89981a05963efcea7ba5c1e967638beeebbbb.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_52a8a323414448c50571a334f29bc0a38919b61d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_52a8a323414448c50571a334f29bc0a38919b61d.hip new file mode 100644 index 000000000000..f0fd94abdbfe --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_52a8a323414448c50571a334f29bc0a38919b61d.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_532a6ffd8a21d3e98342fd401f0247f62ca4e038.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_532a6ffd8a21d3e98342fd401f0247f62ca4e038.hip new file mode 100644 index 000000000000..d95bdc8c5daa --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_532a6ffd8a21d3e98342fd401f0247f62ca4e038.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5344427df3ae9392c4fc4c25c232196828e70648.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5344427df3ae9392c4fc4c25c232196828e70648.hip new file mode 100644 index 000000000000..137e4967d035 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5344427df3ae9392c4fc4c25c232196828e70648.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5382a30dcf702daae19bd6705864bfe36e09502c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5382a30dcf702daae19bd6705864bfe36e09502c.hip new file mode 100644 index 000000000000..8aae4b5333d4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5382a30dcf702daae19bd6705864bfe36e09502c.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_53bd60bd2afee49b30a583c32a45ae9f2076db08.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_53bd60bd2afee49b30a583c32a45ae9f2076db08.hip new file mode 100644 index 000000000000..9e573698aa1a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_53bd60bd2afee49b30a583c32a45ae9f2076db08.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 128, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5403eec1cdd216d5c4a7ba977e2ef92a0d7fcc8b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5403eec1cdd216d5c4a7ba977e2ef92a0d7fcc8b.hip new file mode 100644 index 000000000000..6970d7ae20dd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5403eec1cdd216d5c4a7ba977e2ef92a0d7fcc8b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_540bd57333c6839ccf5cf2e928edb996bc60c371.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_540bd57333c6839ccf5cf2e928edb996bc60c371.hip new file mode 100644 index 000000000000..f7df3f147d7f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_540bd57333c6839ccf5cf2e928edb996bc60c371.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, 
ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_541874a7633e5713720b9d084b6d1c6715a51a17.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_541874a7633e5713720b9d084b6d1c6715a51a17.hip new file mode 100644 index 000000000000..06a5903fcd90 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_541874a7633e5713720b9d084b6d1c6715a51a17.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_54208a6e8c5263e38f9ffcb062564ab61d2785ff.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_54208a6e8c5263e38f9ffcb062564ab61d2785ff.hip new file mode 100644 index 000000000000..ac51939ea413 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_54208a6e8c5263e38f9ffcb062564ab61d2785ff.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5435b4651a90e331fcdcf224282457e3dc038a30.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5435b4651a90e331fcdcf224282457e3dc038a30.hip new file mode 100644 index 000000000000..037237d41d0e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5435b4651a90e331fcdcf224282457e3dc038a30.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_54402a22ceee3b665a3f24edb98b8398c35c6f5a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_54402a22ceee3b665a3f24edb98b8398c35c6f5a.hip new file mode 100644 index 000000000000..ec0eb42fbdd6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_54402a22ceee3b665a3f24edb98b8398c35c6f5a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_54548ad36fb92d0963893146c8db20f53cbf0c8f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_54548ad36fb92d0963893146c8db20f53cbf0c8f.hip new file mode 100644 index 000000000000..8a402676fd31 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_54548ad36fb92d0963893146c8db20f53cbf0c8f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( 
+ s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5467aea26852aa9a9e3dae76b906005ddf6fbae1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5467aea26852aa9a9e3dae76b906005ddf6fbae1.hip new file mode 100644 index 000000000000..14ce32a0c547 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5467aea26852aa9a9e3dae76b906005ddf6fbae1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = 
fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_548b347672451e8391388a400d016803f4c4cf8d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_548b347672451e8391388a400d016803f4c4cf8d.hip new file mode 100644 index 000000000000..70204c41a3be --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_548b347672451e8391388a400d016803f4c4cf8d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_54940ce53998becf9bddf56df7d19894a7658168.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_54940ce53998becf9bddf56df7d19894a7658168.hip new file mode 100644 index 000000000000..94ba3e2921f4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_54940ce53998becf9bddf56df7d19894a7658168.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_549b6956eaf678f7eb901567d1a515eddbedae5f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_549b6956eaf678f7eb901567d1a515eddbedae5f.hip new file mode 100644 index 000000000000..d809c4f367bb --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_549b6956eaf678f7eb901567d1a515eddbedae5f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_54b6e18b10d529eb6b32d7c19c59eaefc7184376.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_54b6e18b10d529eb6b32d7c19c59eaefc7184376.hip new file mode 100644 index 000000000000..087c17dfc241 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_54b6e18b10d529eb6b32d7c19c59eaefc7184376.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + 
constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_54ff49018f1c12b9fa31e523ad40b9cc162ba34d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_54ff49018f1c12b9fa31e523ad40b9cc162ba34d.hip new file mode 100644 index 000000000000..916aab9e71a7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_54ff49018f1c12b9fa31e523ad40b9cc162ba34d.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = 
k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_555ba79201a585bc091ccfc326fd24e851d1eecc.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_555ba79201a585bc091ccfc326fd24e851d1eecc.hip new file mode 100644 index 000000000000..6f75e3aadbdb --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_555ba79201a585bc091ccfc326fd24e851d1eecc.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_556cd05288e1666f5c67fb87ad02ce660e4c589c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_556cd05288e1666f5c67fb87ad02ce660e4c589c.hip new file mode 100644 index 000000000000..55a2358fa98b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_556cd05288e1666f5c67fb87ad02ce660e4c589c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = 
ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_55b14cf2998a61611d1de2594e926fcdc378999c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_55b14cf2998a61611d1de2594e926fcdc378999c.hip new file mode 100644 index 000000000000..b9d2a7223c0d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_55b14cf2998a61611d1de2594e926fcdc378999c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_55bd9c4f1b7a0621c67f3e964d946ce22fb2fc80.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_55bd9c4f1b7a0621c67f3e964d946ce22fb2fc80.hip new file mode 100644 index 000000000000..e9a73092d22c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_55bd9c4f1b7a0621c67f3e964d946ce22fb2fc80.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_55bf8444c1c26b91fd490c7216f4d0f8aa0a1f1a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_55bf8444c1c26b91fd490c7216f4d0f8aa0a1f1a.hip new file mode 100644 index 000000000000..43bec7d74600 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_55bf8444c1c26b91fd490c7216f4d0f8aa0a1f1a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_55cda610c235987e13232e828f8d86fa88030560.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_55cda610c235987e13232e828f8d86fa88030560.hip new file mode 100644 index 000000000000..b7a143108a74 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_55cda610c235987e13232e828f8d86fa88030560.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_55ea83a47c6299fefa4220ed88f7a8e1dd938215.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_55ea83a47c6299fefa4220ed88f7a8e1dd938215.hip new file mode 100644 index 000000000000..beda973479b7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_55ea83a47c6299fefa4220ed88f7a8e1dd938215.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_566b4782793c6526bfce7362efbf6bf069928b2b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_566b4782793c6526bfce7362efbf6bf069928b2b.hip new file mode 100644 index 000000000000..b82f18838201 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_566b4782793c6526bfce7362efbf6bf069928b2b.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_566e26d4969bc6bbe9b092bedab11cddb3360c0f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_566e26d4969bc6bbe9b092bedab11cddb3360c0f.hip new file mode 100644 index 000000000000..44eec56b09da --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_566e26d4969bc6bbe9b092bedab11cddb3360c0f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + 
return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_56964a17f902257aca9d08c736516a2c67d9a0e9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_56964a17f902257aca9d08c736516a2c67d9a0e9.hip new file mode 100644 index 000000000000..bff7d9de06d2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_56964a17f902257aca9d08c736516a2c67d9a0e9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) 
+ return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_56cc4399c5567a9495f17d54c712cc9e65e57521.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_56cc4399c5567a9495f17d54c712cc9e65e57521.hip new file mode 100644 index 000000000000..aa735704459c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_56cc4399c5567a9495f17d54c712cc9e65e57521.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_56de9a7dfb1201b56528740e9d8a07b62710fcaf.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_56de9a7dfb1201b56528740e9d8a07b62710fcaf.hip new file mode 100644 index 000000000000..db75c7e723ed --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_56de9a7dfb1201b56528740e9d8a07b62710fcaf.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_56ffe9e21362afe9c3a407c09d5de186954931a6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_56ffe9e21362afe9c3a407c09d5de186954931a6.hip new file mode 100644 index 000000000000..9451a13d169d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_56ffe9e21362afe9c3a407c09d5de186954931a6.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5724d91c1fd6290a6cf8d52a3801ac6b921dc7d4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5724d91c1fd6290a6cf8d52a3801ac6b921dc7d4.hip new file mode 100644 index 000000000000..591d9a51cd2f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5724d91c1fd6290a6cf8d52a3801ac6b921dc7d4.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_572e68bd619e118292768f0925ccf92cbfa68415.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_572e68bd619e118292768f0925ccf92cbfa68415.hip new file mode 100644 index 000000000000..6c3797072934 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_572e68bd619e118292768f0925ccf92cbfa68415.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + 
constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5732094f5917e9164ee0f973ac6ec47245a69101.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5732094f5917e9164ee0f973ac6ec47245a69101.hip new file mode 100644 index 000000000000..47b5b53a6b1d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5732094f5917e9164ee0f973ac6ec47245a69101.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t 
kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5789f267d34c9961ced63ad07ffea2c6d2911415.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5789f267d34c9961ced63ad07ffea2c6d2911415.hip new file mode 100644 index 000000000000..0f617ac46e58 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5789f267d34c9961ced63ad07ffea2c6d2911415.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5854f09511778dd1779a839b0b194896070f69ad.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5854f09511778dd1779a839b0b194896070f69ad.hip new file mode 100644 index 000000000000..b0b63f90c985 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5854f09511778dd1779a839b0b194896070f69ad.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = 
ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_58679919fcd292a2a69543de0db94e2985c9d364.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_58679919fcd292a2a69543de0db94e2985c9d364.hip new file mode 100644 index 000000000000..068ce698b3fb --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_58679919fcd292a2a69543de0db94e2985c9d364.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_58762476c7f2bb05dce92ec22c0acbeb03676746.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_58762476c7f2bb05dce92ec22c0acbeb03676746.hip new file mode 100644 index 000000000000..00d0f0a1541d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_58762476c7f2bb05dce92ec22c0acbeb03676746.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_587fc33d02b1932235b8d152e57559060211d591.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_587fc33d02b1932235b8d152e57559060211d591.hip new file mode 100644 index 000000000000..186b55badcd3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_587fc33d02b1932235b8d152e57559060211d591.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_58a784fb478ff5b3f1e2da9765a3a777efda92e3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_58a784fb478ff5b3f1e2da9765a3a777efda92e3.hip new file mode 100644 index 000000000000..db36be36544a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_58a784fb478ff5b3f1e2da9765a3a777efda92e3.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); 
+#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_58a7ab44bbd9fbc97c7805860d5f6ac81d6ae468.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_58a7ab44bbd9fbc97c7805860d5f6ac81d6ae468.hip new file mode 100644 index 000000000000..399670f4e8ed --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_58a7ab44bbd9fbc97c7805860d5f6ac81d6ae468.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 128, + true, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, true>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_58eb2edc7738d8d18ac359691da261ceaaf71788.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_58eb2edc7738d8d18ac359691da261ceaaf71788.hip new file mode 100644 index 000000000000..28eaec0e9870 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_58eb2edc7738d8d18ac359691da261ceaaf71788.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5919133d2ed892745013b2fc5d503414cf0a4d83.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5919133d2ed892745013b2fc5d503414cf0a4d83.hip new file mode 100644 index 000000000000..82964aaa461e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5919133d2ed892745013b2fc5d503414cf0a4d83.hip @@ -0,0 +1,14399 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +#include + +template +float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + if(s.log_level_ > 0) + std::cout << ", " << fmha_bwd_dot_do_o_get_name_() << ", " << fmha_bwd_dq_dk_dv_get_name_() << ", " << fmha_bwd_convert_dq_get_name_() << std::flush; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, + [=](const ck_tile::stream_config& s_){ fmha_bwd_dot_do_o_oneshot_(s_, a); }, + [=](const ck_tile::stream_config& s_){ fmha_bwd_dq_dk_dv_oneshot_(s_, a); }, + [=](const ck_tile::stream_config& s_){ fmha_bwd_convert_dq_oneshot_(s_, a); } + ); +#else + return 0.0; +#endif +} + +float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){ + float r = -1; + if(t.data_type.compare("fp16") == 0){ + if (t.hdim_q <= 32 && t.hdim_v <= 32) { + if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = 
fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, 
ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + 
else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 
128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = 
fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, 
ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, 
ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == 
mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && 
t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) 
&& (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + 
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = 
fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = 
fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && 
(t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, 
false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, 
ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else 
if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + 
(a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = 
fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = 
fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, 
ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, 
false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, 
false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + 
(a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = 
fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, 
true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) 
&& (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 
0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = 
fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, 
ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using 
convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = 
fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != 
mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == 
true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + 
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, 
ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && 
(t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && 
(a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = 
fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && 
(t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 
0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = 
fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = 
fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, 
ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, 
ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) 
&& + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, 
ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && 
(t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == 
true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = 
fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, false, true>; 
+ r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + + } + else if (t.hdim_q <= 64 && t.hdim_v <= 64) { + if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && 
(a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, 
false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = 
fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && 
(t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && 
(t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, 
ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, 
ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = 
fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; 
+ } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && 
(t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && 
t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using 
dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, 
ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && 
(t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && 
(a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = 
fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && 
(t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 
0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = 
fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = 
fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, 
ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, 
ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && 
(t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && 
(t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using 
convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type 
== bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 
!= 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using 
dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, 
ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, 
ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using 
convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, true>; 
+ r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && 
(t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = 
fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, 
false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) 
&& (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using 
dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, 
ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using 
convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == 
false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true 
&& t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && 
(a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ 
= fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == 
mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = 
fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + 
} + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, false>; + 
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, 
ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = 
fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, 
false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + + } + else if (t.hdim_q <= 128 && t.hdim_v <= 128) { + if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && 
(t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 
!= 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, 
true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && 
(t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 
0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using 
dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ 
= fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, 
ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, 
ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, 
true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 
128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = 
fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, 
false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == 
false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && 
(a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, 
ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, 
ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, 
true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, 
false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else 
if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 
128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using 
convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type 
!= mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && 
(a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + 
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, 
ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, 
false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = 
fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, true>; 
+ r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) 
&& (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k 
% 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = 
fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type 
== bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + 
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, 
ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, 
ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == 
false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == 
bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + 
(a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + 
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using 
convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && 
(true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, 
ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == 
bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = 
fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return 
r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using 
dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + + } + else if (t.hdim_q <= 256 && t.hdim_v <= 256) { + if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, 
ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + 
r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q 
% 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = 
fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, 
ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = 
fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode 
== false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) 
&& (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && 
(a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) 
&& (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = 
fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, 
true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 
0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, 
true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, 
ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, 
a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true 
&& t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && 
(t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, 
false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, 
ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, 
ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, 
ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == 
false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 
256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, 
true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != 
mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 
64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = 
fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = 
fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, 
ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, 
ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, 
true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && 
(a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = 
fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, 
false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && 
(t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) 
&& (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, 
false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, 
ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, 
false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = 
fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; 
+ } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) 
&& (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + 
return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, 
true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, 
ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 
== 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, 
true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == 
true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::fp16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::fp16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + + } + + } + else if(t.data_type.compare("bf16") == 0){ + if (t.hdim_q <= 32 && t.hdim_v <= 32) { + if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = 
fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, 
ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else 
if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and 
a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using 
dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, 
ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, 
false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == 
mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == 
true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 
0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using 
dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, 
false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) 
&& (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = 
fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, 
true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == 
mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && 
(a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, 
false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, 
ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using 
convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, 
false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && 
(a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, 
false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = 
fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == 
bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using 
dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, 
false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, 
true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && 
(t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == 
false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and 
a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ 
= fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = 
fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == 
false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = 
fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, 
ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, false>; + 
r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == 
false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) 
&& (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, 
false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, 
false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, 
ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, true, true>; + r = 
fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, 
ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, 
ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { 
+ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = 
fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && 
(a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, 
false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 != 0) && (a.hdim_v % 32 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 32 == 0) && (a.hdim_v % 32 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<32, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<32, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + + } + else if (t.hdim_q <= 64 && t.hdim_v <= 64) { + if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) 
&& (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, 
true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, 
false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == 
mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 
!= 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, 
true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return 
r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && 
(t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 
32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v 
% 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = 
fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, 
ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && 
(t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && 
(t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = 
fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == 
bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 
!= 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, 
true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, 
ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using 
convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else 
if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && 
(t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, 
ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, 
ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == 
false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using 
dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, 
ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = 
fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = 
fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != 
mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && 
(t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; 
+ using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using 
convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && 
(t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + 
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, 
ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, 
ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type 
== bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && 
(a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using 
dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using 
dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, 
ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 32 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, true, false>; + r = 
fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = 
fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, 
ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && 
(t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using 
convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + 
(true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, 
ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 != 0) && (a.hdim_v % 64 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 64 == 0) && (a.hdim_v % 64 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<64, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + + } + else if (t.hdim_q <= 128 && t.hdim_v <= 128) { + if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) 
&& (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && 
(t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using 
convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == 
mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && 
(a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, 
false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, 
ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, 
false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = 
fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, 
false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + 
(a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = 
fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, 
true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == 
false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q 
% 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ 
= fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, 
ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using 
convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = 
fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type 
== mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && 
(t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && 
(t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, 
ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, 
false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == 
false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) 
{ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, 
false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, 
false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) 
&& (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == 
bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && 
t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && 
(a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, 
ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = 
fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && 
(t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && 
(t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + 
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else 
if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == 
false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && 
(a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 != 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && 
(t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, 
ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, 
ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && 
(t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return 
r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, false>; + 
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using 
convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && 
(t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 != 0) && (a.hdim_v % 128 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + + } + else if (t.hdim_q <= 256 && t.hdim_v <= 256) { + if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, 
ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && 
(t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = 
fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, 
ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = 
fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && 
(t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && 
t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && 
(a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + 
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, 
true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, 
ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using 
convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) 
&& (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, 
ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, 
ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, 
a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && 
t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == 
true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = 
fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, 
ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, 
ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ 
= fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else 
if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 
!= 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using 
convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != 
mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && 
(a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, 
true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, 
ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, 
false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = 
fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, 
false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 
0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, 
ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using 
convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != 
mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 
== 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, 
ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, 
ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = 
fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return 
r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 != 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && 
(t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && 
t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 != 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (a.seqlen_q % 64 == 0) && (a.seqlen_k % 64 == 0) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, false, false, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, 
ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = 
fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) 
&& (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, 
ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout 
== false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, 
true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != 
mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, true, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 != 0) && (a.hdim_v % 256 != 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, true>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, true, false>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == true)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, false, true>; + r = fmha_bwd_(s, a); + return r; + } + else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && (t.bias_type == bias_enum::alibi) && (t.has_dbias == false) && (t.has_dropout == true && t.is_store_randval == false) && + (true) && (true) && (a.hdim_q % 256 == 0) && (a.hdim_v % 256 == 0) && (t.is_deterministic == false)) { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, true, true, false>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, true, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, true, false, false, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, ck_tile::bf16_t, true, true, false, false>; + r = fmha_bwd_(s, a); + return r; + } + + } + + } + + return r; +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5939e6610e41aff8d1ccdb66d9e84d3e48e8d379.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5939e6610e41aff8d1ccdb66d9e84d3e48e8d379.hip new file mode 100644 index 000000000000..ca8bae010867 --- /dev/null +++ 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5939e6610e41aff8d1ccdb66d9e84d3e48e8d379.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_594929c433b049a8cf949ff476309a8faf5c25fb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_594929c433b049a8cf949ff476309a8faf5c25fb.hip new file mode 100644 index 000000000000..b6cf1cb44116 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_594929c433b049a8cf949ff476309a8faf5c25fb.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float 
fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_597a0276ec419f18f060a5186e6bb703ae434ac8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_597a0276ec419f18f060a5186e6bb703ae434ac8.hip new file mode 100644 index 000000000000..e4098c9f7417 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_597a0276ec419f18f060a5186e6bb703ae434ac8.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
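+//
+// [Editor's sketch -- not emitted by generate.py] Every autogenerated
+// translation unit in this directory follows the same layout: it fixes one
+// (dtype, hdim, pipeline, bias, padding, dropout, determinism) combination
+// through the tile/warp sequences and TileFmhaTraits below, assembles a
+// BlockFmhaBwdPipelineProblem plus dK/dV epilogues, wraps them in an
+// FmhaBwdDQDKDVKernel, and finally provides the explicit specializations of
+// fmha_bwd_dq_dk_dv_, fmha_bwd_dq_dk_dv_oneshot_ and fmha_bwd_dq_dk_dv_get_name_
+// for the matching dq_dk_dv_trait_. The else-if ladder in fmha_bwd.hpp shown
+// earlier maps runtime properties (group mode, mask, bias, dbias, dropout,
+// determinism, and seqlen/hdim divisibility by the tile sizes) onto these
+// compile-time trait flags, picking the KRKTRVR_IGLP pipeline when the head
+// dimension needs no padding and KRKTRVR otherwise. A hypothetical call site,
+// with the trait parameter list abbreviated:
+//
+//   using trait = fmha_bwd_dq_dk_dv_traits_<256, ck_tile::bf16_t, /*group*/ false,
+//       ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, /*mask/dropout/bias/pad flags...*/>;
+//   float ms = fmha_bwd_dq_dk_dv_<trait>(stream_cfg, args);  // assumed usage
+//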
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_59901147b7188212b8d8feea15831a11425fe4b3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_59901147b7188212b8d8feea15831a11425fe4b3.hip new file mode 100644 index 000000000000..f0a3ce4806db --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_59901147b7188212b8d8feea15831a11425fe4b3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_59beb9cb4e161f9dcff79080149076488d436301.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_59beb9cb4e161f9dcff79080149076488d436301.hip new file mode 100644 index 000000000000..544121f30b25 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_59beb9cb4e161f9dcff79080149076488d436301.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_59d366421e0b51c90fa53c366d47ed8d51b3a329.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_59d366421e0b51c90fa53c366d47ed8d51b3a329.hip new file mode 100644 index 000000000000..4f4e7b7ce420 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_59d366421e0b51c90fa53c366d47ed8d51b3a329.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
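+//
+// [Editor's note -- not emitted by generate.py] The launch bodies in these
+// files are fenced by (defined(__gfx90a__) || defined(__gfx942__)), so only
+// MI200- and MI300-class targets actually launch the kernel. For any other
+// offload architecture the timed entry point returns 0.0 and the oneshot
+// entry point is a no-op, which keeps the objects compiling and linking:
+//
+//   #if (defined(__gfx90a__) || defined(__gfx942__))
+//       return ck_tile::launch_kernel(
+//           s, ck_tile::make_kernel<...>(k_{}, grids, blocks, 0, kargs));  // template args elided
+//   #else
+//       return 0.0;  // kernel not built for this target
+//   #endif
+//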
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5a05b4e7782bd0e29ca9f6d33fc59d4304136d41.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5a05b4e7782bd0e29ca9f6d33fc59d4304136d41.hip new file mode 100644 index 000000000000..ca7a148219aa --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5a05b4e7782bd0e29ca9f6d33fc59d4304136d41.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5a216f777feec4752f5882677b18168225da4b53.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5a216f777feec4752f5882677b18168225da4b53.hip new file mode 100644 index 000000000000..51c8dd8841ed --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5a216f777feec4752f5882677b18168225da4b53.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 64, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<64, + ck_tile::bf16_t, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5a29b93cee012c79d4364502f1d90f947c73641d.hip 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5a29b93cee012c79d4364502f1d90f947c73641d.hip new file mode 100644 index 000000000000..f837edec0250 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5a29b93cee012c79d4364502f1d90f947c73641d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void 
fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5a85ae0a16e4b293b549bcb6a3ee52df7fccca32.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5a85ae0a16e4b293b549bcb6a3ee52df7fccca32.hip new file mode 100644 index 000000000000..fe6ae2494351 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5a85ae0a16e4b293b549bcb6a3ee52df7fccca32.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 64, + 256, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<256, + ck_tile::bf16_t, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5aba1183efe205af38e79a1b2dccea5fa515d02e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5aba1183efe205af38e79a1b2dccea5fa515d02e.hip new file mode 100644 index 000000000000..d63a67ffba90 --- /dev/null +++ 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5aba1183efe205af38e79a1b2dccea5fa515d02e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5ace1c9b00f160a17355d4583d49c47887ac33c8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5ace1c9b00f160a17355d4583d49c47887ac33c8.hip new file mode 100644 index 000000000000..4e4e62b82a8b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5ace1c9b00f160a17355d4583d49c47887ac33c8.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float 
fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5af96b404feac271dac8f4190180754480d3ba80.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5af96b404feac271dac8f4190180754480d3ba80.hip new file mode 100644 index 000000000000..147c5f04af06 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5af96b404feac271dac8f4190180754480d3ba80.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
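
Editorial note between files: the generated .hip files above all follow one shape — define a single trait combination, then provide explicit specializations of the timed launcher, the fire-and-forget `oneshot_` launcher, and a `get_name_` helper, with the actual kernel launch compiled only for gfx90a/gfx942 and a 0.0 stub elsewhere. Below is a minimal host-side sketch of that pattern; all `Example*`/`example_*` names are hypothetical stand-ins, not the real ck_tile or fmha API.

// Sketch of the per-translation-unit specialization pattern used by the
// autogenerated launchers. Hypothetical names only.
#include <iostream>
#include <string>

struct ExampleStream { int log_level_ = 0; };  // stands in for ck_tile::stream_config
struct ExampleArgs   { };                      // stands in for fmha_bwd_args

// Declared once in a shared header, defined per trait in each generated file.
template <typename Trait> float       example_bwd_(const ExampleStream& s, ExampleArgs a);
template <typename Trait> void        example_bwd_oneshot_(const ExampleStream& s, ExampleArgs a);
template <typename Trait> std::string example_bwd_get_name_();

struct ExampleTrait {};  // stands in for one fmha_bwd_dq_dk_dv_traits_<...> alias

template <>
std::string example_bwd_get_name_<ExampleTrait>() { return "example_bwd_kernel"; }

template <>
float example_bwd_<ExampleTrait>(const ExampleStream& s, ExampleArgs /*a*/)
{
    if (s.log_level_ > 0)
        std::cout << ", " << example_bwd_get_name_<ExampleTrait>() << std::flush;
#if (defined(__gfx90a__) || defined(__gfx942__))
    // The real files build kargs/grids here, launch the ck_tile kernel and
    // return the measured launch time.
    return 1.0f;
#else
    // On other targets the launcher compiles to a stub that reports 0.0.
    return 0.0f;
#endif
}

template <>
void example_bwd_oneshot_<ExampleTrait>(const ExampleStream&, ExampleArgs)
{
    // Same launch as above, but fired once on the stream instead of timed.
}
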
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5b413bdc825ae863d53dab548f2145dc0de8fd37.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5b413bdc825ae863d53dab548f2145dc0de8fd37.hip new file mode 100644 index 000000000000..e57559fe906e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5b413bdc825ae863d53dab548f2145dc0de8fd37.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 128, + true, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, true, true, false>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5b55946ff3c15a44b9c741e9f6bbbcb5bd4c8577.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5b55946ff3c15a44b9c741e9f6bbbcb5bd4c8577.hip new file mode 100644 index 000000000000..072e763c87ef --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5b55946ff3c15a44b9c741e9f6bbbcb5bd4c8577.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5b7a4ea3bb8905a22ae97a94c354b1cbe38093bb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5b7a4ea3bb8905a22ae97a94c354b1cbe38093bb.hip new file mode 100644 index 000000000000..77e48a6748ff --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5b7a4ea3bb8905a22ae97a94c354b1cbe38093bb.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( 
+ s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5ba578c0e7abf1127dd0370f06d7278656c93ab9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5ba578c0e7abf1127dd0370f06d7278656c93ab9.hip new file mode 100644 index 000000000000..5c55f1f2502b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5ba578c0e7abf1127dd0370f06d7278656c93ab9.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 32, + true, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, true, true, false>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5bc803342862aa30e23e5be7d84e611bc571c529.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5bc803342862aa30e23e5be7d84e611bc571c529.hip new file mode 100644 index 
000000000000..9079d495034a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5bc803342862aa30e23e5be7d84e611bc571c529.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5be9ed84ad9be1627db7a66af9370679816c0897.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5be9ed84ad9be1627db7a66af9370679816c0897.hip new file mode 100644 index 000000000000..e9d1cecfd893 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5be9ed84ad9be1627db7a66af9370679816c0897.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
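
Editorial note between files: alongside the main `fmha_bwd_dq_dk_dv_` instantiations, this part of the patch also adds `fmha_bwd_dot_do_o_` and `fmha_bwd_convert_dq_` variants. Assuming the usual FlashAttention backward split, the first precomputes the per-row rowsum(dO * O) term and the last casts the fp32 dQ accumulator back to the fp16/bf16 output dtype. A hedged sketch of how a dispatcher might chain the three stages; every `example_*` name is hypothetical and the real fmha_bwd API may differ.

// Sketch of the three-stage backward decomposition visible in these files.
struct ExampleStream {};
struct ExampleArgs {};

// Stage 1: dot_do_o -- per-row D = rowsum(dO * O), consumed by the main kernel.
float example_dot_do_o(const ExampleStream&, const ExampleArgs&) { return 0.f; }
// Stage 2: dq_dk_dv -- main backward kernel producing dK, dV and an fp32 dQ accumulator.
float example_dq_dk_dv(const ExampleStream&, const ExampleArgs&) { return 0.f; }
// Stage 3: convert_dq -- cast the fp32 dQ accumulator to the output dtype.
float example_convert_dq(const ExampleStream&, const ExampleArgs&) { return 0.f; }

// A dispatcher would pick one generated specialization per stage and run the
// stages back to back on the same stream, summing the reported launch times.
float example_fmha_bwd(const ExampleStream& s, const ExampleArgs& a)
{
    float t = 0.f;
    t += example_dot_do_o(s, a);
    t += example_dq_dk_dv(s, a);
    t += example_convert_dq(s, a);
    return t;
}
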
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5bead6be6e39ece0e5d44335083336f7f546d2f8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5bead6be6e39ece0e5d44335083336f7f546d2f8.hip new file mode 100644 index 000000000000..398cc46c34a4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5bead6be6e39ece0e5d44335083336f7f546d2f8.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5c36fc744dfb0d985c9113175e76c7ec1c935054.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5c36fc744dfb0d985c9113175e76c7ec1c935054.hip new file mode 100644 index 000000000000..2fe7b7fef1f0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5c36fc744dfb0d985c9113175e76c7ec1c935054.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5c742b9ac6749f189d597ac97d46d35189472c50.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5c742b9ac6749f189d597ac97d46d35189472c50.hip new file mode 100644 index 000000000000..b98e8eb51d13 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5c742b9ac6749f189d597ac97d46d35189472c50.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5cd03e29403ad53d6d52e5e81182ea6ff5aff2be.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5cd03e29403ad53d6d52e5e81182ea6ff5aff2be.hip new file mode 100644 index 000000000000..b9bf06ceaf63 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5cd03e29403ad53d6d52e5e81182ea6ff5aff2be.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5cd41b6f578f3c903eb9d58ebfab62eb296044e0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5cd41b6f578f3c903eb9d58ebfab62eb296044e0.hip new file mode 100644 index 000000000000..15753ff0a0dd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5cd41b6f578f3c903eb9d58ebfab62eb296044e0.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = 
fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5d707d065ae152450f9def619ddc3dddb9089e88.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5d707d065ae152450f9def619ddc3dddb9089e88.hip new file mode 100644 index 000000000000..ac0e634a6ab4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5d707d065ae152450f9def619ddc3dddb9089e88.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if 
(defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5d7ed4c885fb32a0b548186e56d64bab98071d30.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5d7ed4c885fb32a0b548186e56d64bab98071d30.hip new file mode 100644 index 000000000000..2d294dc9e073 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5d7ed4c885fb32a0b548186e56d64bab98071d30.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using 
fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5daedab8931f2eefb649b91e80145cb71b63360c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5daedab8931f2eefb649b91e80145cb71b63360c.hip new file mode 100644 index 000000000000..6bbb194c6f9c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5daedab8931f2eefb649b91e80145cb71b63360c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5de27c4081377f59363c2bf2ea8624217566d2d3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5de27c4081377f59363c2bf2ea8624217566d2d3.hip new file mode 100644 index 000000000000..605258d502ec --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5de27c4081377f59363c2bf2ea8624217566d2d3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5e0abf4e2b6be3e2c555c2134705b9dcaee617ce.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5e0abf4e2b6be3e2c555c2134705b9dcaee617ce.hip new file mode 100644 index 000000000000..6375ef7d3f2d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5e0abf4e2b6be3e2c555c2134705b9dcaee617ce.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5e62968de58d9df7d687d671f37d63393f189321.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5e62968de58d9df7d687d671f37d63393f189321.hip new file mode 100644 index 000000000000..8b7d6afbb742 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5e62968de58d9df7d687d671f37d63393f189321.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5e735b12d130ebf849ac5d6752e413ecf3e69fbf.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5e735b12d130ebf849ac5d6752e413ecf3e69fbf.hip new file mode 100644 index 000000000000..9aafba9d588b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5e735b12d130ebf849ac5d6752e413ecf3e69fbf.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5e840be0741afa4d41fd4789c8300223fdc63ddc.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5e840be0741afa4d41fd4789c8300223fdc63ddc.hip new file mode 100644 index 000000000000..4713711a1efd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5e840be0741afa4d41fd4789c8300223fdc63ddc.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 64, + true, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, true, true, true>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5ea53f7c6370845fa94aa9b395c52fd1900b62de.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5ea53f7c6370845fa94aa9b395c52fd1900b62de.hip new file mode 
100644 index 000000000000..42f408026713 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5ea53f7c6370845fa94aa9b395c52fd1900b62de.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto 
[kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5efe77ca5c394a60af0313072cdd132216a52bf3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5efe77ca5c394a60af0313072cdd132216a52bf3.hip new file mode 100644 index 000000000000..b2d7b8a8c806 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5efe77ca5c394a60af0313072cdd132216a52bf3.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5f20263fd84776f155519b3481be5e2c5b035585.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5f20263fd84776f155519b3481be5e2c5b035585.hip 
new file mode 100644 index 000000000000..49d85122882f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5f20263fd84776f155519b3481be5e2c5b035585.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 32, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<32, + ck_tile::fp16_t, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5f3c3bed2b584ea2031debf9f953f5f8f7012171.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5f3c3bed2b584ea2031debf9f953f5f8f7012171.hip new file mode 100644 index 000000000000..811bee38dadb --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5f3c3bed2b584ea2031debf9f953f5f8f7012171.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
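Every generated translation unit above follows the same shape: it defines a `*_trait_0` alias that encodes one compile-time kernel configuration (tile sizes, pipeline enum, bias/dropout/padding flags) and then provides an explicit specialization of a dispatch function template (`fmha_bwd_dq_dk_dv_<trait>`, `fmha_bwd_dot_do_o_<trait>`, `fmha_bwd_convert_dq_<trait>`, or `fmha_fwd_<trait>`) that launches the matching kernel, guarded to gfx90a/gfx942. The standalone sketch below only illustrates that explicit-specialization dispatch pattern in plain C++; the names `my_trait` and `run_kernel_` and the printed strings are illustrative placeholders and are not part of ck_tile or of this patch.

#include <iostream>

// Hypothetical trait tag standing in for the generated *_trait_0 aliases:
// each trait bundles one compile-time kernel configuration.
template <int HeadDim, bool kIsCausal>
struct my_trait {};

// Primary template: declared but intentionally not defined, so only
// configurations that were explicitly instantiated (i.e. generated) link.
template <typename Trait>
float run_kernel_(float scale);

// One explicit specialization per generated configuration, mirroring how each
// autogenerated .hip file specializes e.g. fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>.
template <>
float run_kernel_<my_trait<128, true>>(float scale)
{
    std::cout << "launching hd128 causal kernel\n";
    return scale * 1.0f; // stand-in for the measured kernel time
}

template <>
float run_kernel_<my_trait<64, false>>(float scale)
{
    std::cout << "launching hd64 non-causal kernel\n";
    return scale * 2.0f;
}

int main()
{
    // A runtime dispatcher maps (head_dim, is_causal, ...) onto one of the
    // explicitly instantiated traits and calls the matching specialization.
    float t = run_kernel_<my_trait<128, true>>(1.0f);
    std::cout << "elapsed: " << t << " ms\n";
    return 0;
}

Keeping each configuration in its own translation unit, selected purely through an explicitly specialized trait, lets the build compile configurations in parallel and lets the runtime dispatcher stay a thin switch over trait parameters; any configuration that was never generated simply fails to link rather than silently falling back.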
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5f71e663978dbcba859c5114ec675a712e343fd6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5f71e663978dbcba859c5114ec675a712e343fd6.hip new file mode 100644 index 000000000000..317b0ffb52e3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5f71e663978dbcba859c5114ec675a712e343fd6.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 64, + 256, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<256, + ck_tile::fp16_t, + false, + false, + true, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5f8925f929a5b26f3544ca31938aa75b3c59d34d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5f8925f929a5b26f3544ca31938aa75b3c59d34d.hip new file mode 100644 index 000000000000..a46e8c9d1a06 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5f8925f929a5b26f3544ca31938aa75b3c59d34d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5f954a393b7b5a7131c13d0c4578443f468a738d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5f954a393b7b5a7131c13d0c4578443f468a738d.hip new file mode 100644 index 000000000000..20ed87ac5b4c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5f954a393b7b5a7131c13d0c4578443f468a738d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5fa19223cf296d7fd10e15e2571e63c84a80fbb1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5fa19223cf296d7fd10e15e2571e63c84a80fbb1.hip new file mode 100644 index 000000000000..caf989f9a71b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5fa19223cf296d7fd10e15e2571e63c84a80fbb1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5fa7fafd4227918e0c7f0c6ca3b2bd673cd07279.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5fa7fafd4227918e0c7f0c6ca3b2bd673cd07279.hip new file mode 100644 index 000000000000..9c7f4c738f26 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5fa7fafd4227918e0c7f0c6ca3b2bd673cd07279.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5fb062527121e627871b3f1b2a94b96c42e51205.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5fb062527121e627871b3f1b2a94b96c42e51205.hip new file mode 100644 index 000000000000..86d4715e4aef --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5fb062527121e627871b3f1b2a94b96c42e51205.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5fc66c5b53f83bf1e023e81e9d51f0285b3ae731.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5fc66c5b53f83bf1e023e81e9d51f0285b3ae731.hip new file mode 100644 index 000000000000..064f6e873cea --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_5fc66c5b53f83bf1e023e81e9d51f0285b3ae731.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
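Two conventions recur in every launcher above: the kernel body is compiled only when targeting gfx90a or gfx942 (otherwise the function returns 0.0 and does nothing), and on the supported path the launcher reports the elapsed time returned by ck_tile::launch_kernel. A small self-contained sketch of that guard-and-time convention, assuming a hypothetical timed_launch helper in place of ck_tile's launcher:

// Illustrative only -- names below are hypothetical, not ck_tile's API.
#include <chrono>
#include <iostream>

template <typename KernelBody>
float timed_launch(KernelBody&& body) {
#if (defined(__gfx90a__) || defined(__gfx942__))
    // On a supported device target the real code enqueues the kernel on the
    // HIP stream and times it; here we simply time a host callable.
    auto t0 = std::chrono::steady_clock::now();
    body();
    auto t1 = std::chrono::steady_clock::now();
    return std::chrono::duration<float, std::milli>(t1 - t0).count();
#else
    (void)body;
    return 0.0f;  // same "do nothing, report zero" fallback as the generated files
#endif
}

int main() {
    float ms = timed_launch([] { /* pretend kernel work */ });
    std::cout << "elapsed: " << ms << " ms\n";
}

The zero return keeps benchmarking loops well defined when the object file is built for an architecture the CK kernels do not support.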
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6018ab272d7306689c7dc5a6d5326efea1471235.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6018ab272d7306689c7dc5a6d5326efea1471235.hip new file mode 100644 index 000000000000..8f71d31bf6a3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6018ab272d7306689c7dc5a6d5326efea1471235.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6049c01db99fce654e9351e711b113cf7424550a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6049c01db99fce654e9351e711b113cf7424550a.hip new file mode 100644 index 000000000000..64a88d00c1a0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6049c01db99fce654e9351e711b113cf7424550a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_606f5e0b99814b0a82a731de36f28024bc317801.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_606f5e0b99814b0a82a731de36f28024bc317801.hip new file mode 100644 index 000000000000..c844b00596f8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_606f5e0b99814b0a82a731de36f28024bc317801.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_60801d21c14796c08377349ec86a6c800af497b7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_60801d21c14796c08377349ec86a6c800af497b7.hip new file mode 100644 index 000000000000..3df31cb5546f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_60801d21c14796c08377349ec86a6c800af497b7.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6082d55544b5280b49b071ea277fb1827193fa2a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6082d55544b5280b49b071ea277fb1827193fa2a.hip new file mode 100644 index 000000000000..e643421efbc1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6082d55544b5280b49b071ea277fb1827193fa2a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
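The launchers also read all kernel metadata at compile time: BlockSize() and kBlockPerCu are constexpr members of the kernel type, and GetName() is only consulted for logging when log_level_ > 0. A short standalone sketch of that pattern with a made-up kernel type (not ck_tile's FmhaFwdKernel or FmhaBwdDQDKDVKernel):

// Illustrative only -- FakeFwdKernel is hypothetical.
#include <iostream>
#include <string>

struct dim3_ { unsigned x, y, z; };

struct FakeFwdKernel {
    static constexpr dim3_ BlockSize() { return {256, 1, 1}; }  // threads per block
    static constexpr int   kBlockPerCu = 2;                     // occupancy hint
    static std::string GetName() { return "fmha_fwd_hd128_fp16_illustrative"; }
};

template <typename K>
void describe() {
    constexpr dim3_ blocks = K::BlockSize();  // mirrors `constexpr dim3 blocks = k_::BlockSize();`
    constexpr int   per_cu = K::kBlockPerCu;  // mirrors `constexpr ck_tile::index_t kBlockPerCu`
    std::cout << K::GetName() << ": " << blocks.x << " threads/block, "
              << per_cu << " blocks/CU\n";
}

int main() { describe<FakeFwdKernel>(); }

Because the metadata is constexpr, the launch configuration is fixed per generated translation unit and carries no runtime lookup cost.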
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_609616f72bf16a060fa50091ac139ddc06bf9d88.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_609616f72bf16a060fa50091ac139ddc06bf9d88.hip new file mode 100644 index 000000000000..1147254334ff --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_609616f72bf16a060fa50091ac139ddc06bf9d88.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_609f68180582384ba81aae2b1d4a4c52dde2c68c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_609f68180582384ba81aae2b1d4a4c52dde2c68c.hip new file mode 100644 index 000000000000..4e7b609594d4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_609f68180582384ba81aae2b1d4a4c52dde2c68c.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 64, + 256, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<256, + ck_tile::bf16_t, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_60efa9c427dc278c0d1bc31189f683cd45e4d873.hip 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_60efa9c427dc278c0d1bc31189f683cd45e4d873.hip new file mode 100644 index 000000000000..6797e80b7a47 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_60efa9c427dc278c0d1bc31189f683cd45e4d873.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_61204f6805d5d830aa6fca2a9b5f238ed63c3a73.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_61204f6805d5d830aa6fca2a9b5f238ed63c3a73.hip new file mode 100644 index 000000000000..c3fe918835d5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_61204f6805d5d830aa6fca2a9b5f238ed63c3a73.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_61220f6dca850a5b5ccf1f619a267c40c37efeca.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_61220f6dca850a5b5ccf1f619a267c40c37efeca.hip new file mode 100644 index 000000000000..59421208ec88 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_61220f6dca850a5b5ccf1f619a267c40c37efeca.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 64, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, true>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_614a9f10ebc51bde3f580ef527c17f89489c12c7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_614a9f10ebc51bde3f580ef527c17f89489c12c7.hip new file mode 100644 index 000000000000..2b9311b98371 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_614a9f10ebc51bde3f580ef527c17f89489c12c7.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_615430cb65d8d540836c7f12b3367abd3c8e63d2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_615430cb65d8d540836c7f12b3367abd3c8e63d2.hip new file mode 100644 index 000000000000..ea187241a926 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_615430cb65d8d540836c7f12b3367abd3c8e63d2.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_618031345ea71cc17e458eb97a559b7c94d3ae43.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_618031345ea71cc17e458eb97a559b7c94d3ae43.hip new file mode 100644 index 000000000000..1ad8d796b0fd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_618031345ea71cc17e458eb97a559b7c94d3ae43.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; 
+ +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_61896aa9e4e4d7e494c1755b1e77a08e0e264f8d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_61896aa9e4e4d7e494c1755b1e77a08e0e264f8d.hip new file mode 100644 index 000000000000..11228066f779 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_61896aa9e4e4d7e494c1755b1e77a08e0e264f8d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_61a44ac409e914c12281f1d26e5b52d8bfd0df75.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_61a44ac409e914c12281f1d26e5b52d8bfd0df75.hip new file mode 100644 index 000000000000..c0b760b06400 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_61a44ac409e914c12281f1d26e5b52d8bfd0df75.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_61a9e92183ba87924e73ff0b5e25bd12d6038e69.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_61a9e92183ba87924e73ff0b5e25bd12d6038e69.hip new file mode 100644 index 000000000000..1ada930929eb --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_61a9e92183ba87924e73ff0b5e25bd12d6038e69.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else 
+ return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_62048a8ae1c0096f3372b0114c15edbe813425fd.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_62048a8ae1c0096f3372b0114c15edbe813425fd.hip new file mode 100644 index 000000000000..1e3eea6a16ba --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_62048a8ae1c0096f3372b0114c15edbe813425fd.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6214f820b39a8ba81e547a78ed19a909ac13221c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6214f820b39a8ba81e547a78ed19a909ac13221c.hip new file mode 100644 index 000000000000..59a4cfac8c81 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6214f820b39a8ba81e547a78ed19a909ac13221c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_621da34ee666903307d3a09b7a032f2a70054759.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_621da34ee666903307d3a09b7a032f2a70054759.hip new file mode 100644 index 000000000000..73ed5ba2bcd2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_621da34ee666903307d3a09b7a032f2a70054759.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_628b28f65f19e7d1b22fb3b85b7cf3d09cd54ebc.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_628b28f65f19e7d1b22fb3b85b7cf3d09cd54ebc.hip new file mode 100644 index 000000000000..37f7b6d970f3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_628b28f65f19e7d1b22fb3b85b7cf3d09cd54ebc.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_629e0b97b3fece7c12504f4c8f1860d611b57269.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_629e0b97b3fece7c12504f4c8f1860d611b57269.hip new file mode 100644 index 000000000000..56114de5ec3e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_629e0b97b3fece7c12504f4c8f1860d611b57269.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_62ab710e4acc711430745e05e036dd6a4d6bcdca.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_62ab710e4acc711430745e05e036dd6a4d6bcdca.hip new file mode 100644 index 000000000000..f3e9617746e0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_62ab710e4acc711430745e05e036dd6a4d6bcdca.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_62ba7a5a0f3a714eb5f9f2af20f7bfbc82a30350.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_62ba7a5a0f3a714eb5f9f2af20f7bfbc82a30350.hip new file mode 100644 index 000000000000..107eaac4bb6f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_62ba7a5a0f3a714eb5f9f2af20f7bfbc82a30350.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_62eb2f81e73d65fddce7ff43c397da6529317607.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_62eb2f81e73d65fddce7ff43c397da6529317607.hip new file mode 100644 index 000000000000..d5268a6cfed3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_62eb2f81e73d65fddce7ff43c397da6529317607.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_634d530731c7ade2c7beecfd1bbbca8583032217.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_634d530731c7ade2c7beecfd1bbbca8583032217.hip new file mode 100644 index 000000000000..a9bf4ac89660 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_634d530731c7ade2c7beecfd1bbbca8583032217.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6360621af3f7e1e81a8be48fea8d2750fdecbbf4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6360621af3f7e1e81a8be48fea8d2750fdecbbf4.hip new file mode 100644 index 000000000000..bd58ae5b073c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6360621af3f7e1e81a8be48fea8d2750fdecbbf4.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6376eb68c550b50b9aea42a7a2cc3bda186b0e40.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6376eb68c550b50b9aea42a7a2cc3bda186b0e40.hip new file mode 100644 index 000000000000..d7dcbfd10a22 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6376eb68c550b50b9aea42a7a2cc3bda186b0e40.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_63c411351ec59bdbed2590c599f9eddf7807b371.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_63c411351ec59bdbed2590c599f9eddf7807b371.hip new file mode 100644 index 000000000000..9a308c4fac7c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_63c411351ec59bdbed2590c599f9eddf7807b371.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_63f121a3c8928c10a2d86b487cd13fa995da670d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_63f121a3c8928c10a2d86b487cd13fa995da670d.hip new file mode 100644 index 000000000000..e9eb60654f2d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_63f121a3c8928c10a2d86b487cd13fa995da670d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_643b3798f11997d33ccb58d90ed6c10d5411b735.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_643b3798f11997d33ccb58d90ed6c10d5411b735.hip new file mode 100644 index 000000000000..314cb5f5ad78 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_643b3798f11997d33ccb58d90ed6c10d5411b735.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_649336d59a8b35919e593217b6fd4314a04ea359.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_649336d59a8b35919e593217b6fd4314a04ea359.hip new file mode 100644 index 000000000000..d39bc8d5dd3f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_649336d59a8b35919e593217b6fd4314a04ea359.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_64a0ca185449a49fa485892fde6af745ba758167.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_64a0ca185449a49fa485892fde6af745ba758167.hip new file mode 100644 index 000000000000..3aae70aeb0ed --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_64a0ca185449a49fa485892fde6af745ba758167.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_64b3488ddf3bb1a4870371882f0a5d267bdfdf73.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_64b3488ddf3bb1a4870371882f0a5d267bdfdf73.hip new file mode 100644 index 000000000000..f3205430ef16 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_64b3488ddf3bb1a4870371882f0a5d267bdfdf73.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_64c3c1e3dac623f07c2dc1b934ccb868cafcb38c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_64c3c1e3dac623f07c2dc1b934ccb868cafcb38c.hip new file mode 100644 index 000000000000..7f03548932d6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_64c3c1e3dac623f07c2dc1b934ccb868cafcb38c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_64cf03c0aa3f1b2a7b76b4e3418eb5063b982a29.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_64cf03c0aa3f1b2a7b76b4e3418eb5063b982a29.hip new file mode 100644 index 000000000000..c5ef2de1b23c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_64cf03c0aa3f1b2a7b76b4e3418eb5063b982a29.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_64fe2db75cb20428856b02cd1cc8d7b393a6ad9c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_64fe2db75cb20428856b02cd1cc8d7b393a6ad9c.hip new file mode 100644 index 000000000000..c90b78d5b365 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_64fe2db75cb20428856b02cd1cc8d7b393a6ad9c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_65794d9c185b21f59274ac5d4db10a7abc0be968.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_65794d9c185b21f59274ac5d4db10a7abc0be968.hip new file mode 100644 index 000000000000..16f8f996144e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_65794d9c185b21f59274ac5d4db10a7abc0be968.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_658552954505a2092662071401e135e84956c4c0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_658552954505a2092662071401e135e84956c4c0.hip new file mode 100644 index 000000000000..ecf63208bb11 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_658552954505a2092662071401e135e84956c4c0.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 128, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_65910c8b7a30acc731948ab58467fdbe4fe32f6d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_65910c8b7a30acc731948ab58467fdbe4fe32f6d.hip new file mode 100644 index 000000000000..ae319eab4af3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_65910c8b7a30acc731948ab58467fdbe4fe32f6d.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+
+// auto generated by generate.py
+#include
+
+using fmha_dtype_0 = ck_tile::fp16_t;
+
+using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>;
+using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>;
+
+using fmha_shape_0 = ck_tile::TileFmhaShape,
+    fmha_warp_tile_0,
+    ck_tile::sequence<2, 1, 1>,
+    fmha_warp_tile_0,
+    true>;
+
+using fmha_trait_0 = ck_tile::TileFmhaTraits;
+using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask;
+
+using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem<
+    typename FmhaFwdTypeConfig::QDataType,
+    typename FmhaFwdTypeConfig::KDataType,
+    typename FmhaFwdTypeConfig::VDataType,
+    typename FmhaFwdTypeConfig::SaccDataType,
+    typename FmhaFwdTypeConfig::SMPLComputeDataType,
+    typename FmhaFwdTypeConfig::BiasDataType,
+    typename FmhaFwdTypeConfig::RandValOutputDataType,
+    typename FmhaFwdTypeConfig::LSEDataType,
+    typename FmhaFwdTypeConfig::PDataType,
+    typename FmhaFwdTypeConfig::OaccDataType,
+    typename FmhaFwdTypeConfig::ODataType,
+    fmha_shape_0,
+    false,
+    fmha_mask_0,
+    fmha_trait_0>;
+
+using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync<
+    fmha_pipeline_problem_0>;
+
+using fmha_epilogue_0 =
+    ck_tile::Default2DEpilogue::OaccDataType,
+    typename FmhaFwdTypeConfig::ODataType,
+    true, true>>;
+
+using fmha_kernel_0 =
+    ck_tile::FmhaFwdKernel,
+    fmha_pipeline_0,
+    fmha_epilogue_0>;
+
+using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true,
+    ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, false, true, true>;
+
+#include
+
+template<>
+float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a)
+{
+    using k_ = fmha_kernel_0;
+    if(s.log_level_ > 0)
+        std::cout << ", " << k_::GetName() << std::flush;
+    auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a);
+    constexpr dim3 blocks = k_::BlockSize();
+    constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
+#if (defined(__gfx90a__) || defined(__gfx942__))
+    return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs));
+#else
+    return 0.0;
+#endif
+}
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_661b49505cfecbe4ec3e5c7371de3aaaa85ac9d5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_661b49505cfecbe4ec3e5c7371de3aaaa85ac9d5.hip
new file mode 100644
index 000000000000..7d29797cc8a0
--- /dev/null
+++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_661b49505cfecbe4ec3e5c7371de3aaaa85ac9d5.hip
@@ -0,0 +1,84 @@
+// ==========================================
+// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
+// @generated
+// ==========================================
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+
+// auto generated by generate.py
+#include
+
+using fmha_dtype_0 = ck_tile::fp16_t;
+
+using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>;
+using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>;
+
+using fmha_shape_0 = ck_tile::TileFmhaShape,
+    fmha_warp_tile_0,
+    ck_tile::sequence<2, 1, 1>,
+    fmha_warp_tile_0,
+    true>;
+
+using fmha_trait_0 = ck_tile::TileFmhaTraits;
+using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask;
+
+using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem<
+    typename FmhaFwdTypeConfig::QDataType,
+    typename FmhaFwdTypeConfig::KDataType,
+    typename FmhaFwdTypeConfig::VDataType,
+    typename FmhaFwdTypeConfig::SaccDataType,
+    typename FmhaFwdTypeConfig::SMPLComputeDataType,
+    typename FmhaFwdTypeConfig::BiasDataType,
+    typename FmhaFwdTypeConfig::RandValOutputDataType,
+    typename FmhaFwdTypeConfig::LSEDataType,
+    typename FmhaFwdTypeConfig::PDataType,
+    typename FmhaFwdTypeConfig::OaccDataType,
+    typename FmhaFwdTypeConfig::ODataType,
+    fmha_shape_0,
+    false,
+    fmha_mask_0,
+    fmha_trait_0>;
+
+using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync<
+    fmha_pipeline_problem_0>;
+
+using fmha_epilogue_0 =
+    ck_tile::Default2DEpilogue::OaccDataType,
+    typename FmhaFwdTypeConfig::ODataType,
+    true, true>>;
+
+using fmha_kernel_0 =
+    ck_tile::FmhaFwdKernel,
+    fmha_pipeline_0,
+    fmha_epilogue_0>;
+
+using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true,
+    ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>;
+
+#include
+
+template<>
+float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a)
+{
+    using k_ = fmha_kernel_0;
+    if(s.log_level_ > 0)
+        std::cout << ", " << k_::GetName() << std::flush;
+    auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a);
+    constexpr dim3 blocks = k_::BlockSize();
+    constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
+#if (defined(__gfx90a__) || defined(__gfx942__))
+    return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs));
+#else
+    return 0.0;
+#endif
+}
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_661ffaf653085dd7f122d603bb3ba4b001e5f3c0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_661ffaf653085dd7f122d603bb3ba4b001e5f3c0.hip
new file mode 100644
index 000000000000..71d6d758ce8f
--- /dev/null
+++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_661ffaf653085dd7f122d603bb3ba4b001e5f3c0.hip
@@ -0,0 +1,144 @@
+// ==========================================
+// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
+// @generated
+// ==========================================
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_662767e588220d0dc6137b00cc1d8dcc91e97134.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_662767e588220d0dc6137b00cc1d8dcc91e97134.hip new file mode 100644 index 000000000000..93eb241144db --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_662767e588220d0dc6137b00cc1d8dcc91e97134.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) 
+ return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6649f19deeaea20663bee781af7edced7f7a4fc0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6649f19deeaea20663bee781af7edced7f7a4fc0.hip new file mode 100644 index 000000000000..09dcbd26f3b8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6649f19deeaea20663bee781af7edced7f7a4fc0.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, 
kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_66968bbf7e210911fcb95ba90c79837230ab1ce3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_66968bbf7e210911fcb95ba90c79837230ab1ce3.hip new file mode 100644 index 000000000000..9aabd72f0b44 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_66968bbf7e210911fcb95ba90c79837230ab1ce3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_66a020f728df204ff51e37d2ddc21afb0aad5e7b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_66a020f728df204ff51e37d2ddc21afb0aad5e7b.hip new file mode 100644 index 000000000000..6aef15d3c90f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_66a020f728df204ff51e37d2ddc21afb0aad5e7b.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, 
grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_66be70b088b20fc8de464167c35745461ddab640.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_66be70b088b20fc8de464167c35745461ddab640.hip new file mode 100644 index 000000000000..eb2b20126c86 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_66be70b088b20fc8de464167c35745461ddab640.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if 
(defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_66f651d3415562206c1049b172261fddba01ea6c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_66f651d3415562206c1049b172261fddba01ea6c.hip new file mode 100644 index 000000000000..951a884ab1ec --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_66f651d3415562206c1049b172261fddba01ea6c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using 
fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_671828f15eec2a58be23063a1a8132d337cd26de.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_671828f15eec2a58be23063a1a8132d337cd26de.hip new file mode 100644 index 000000000000..d98041e8984a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_671828f15eec2a58be23063a1a8132d337cd26de.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+
+// auto generated by generate.py
+#include
+
+using fmha_dtype_0 = ck_tile::fp16_t;
+
+using fmha_bwd_convert_dq_trait_0 =
+    ck_tile::TileFmhaBwdConvertQGradTraits;
+
+using fmha_bwd_convert_dq_pipeline_problem_0 =
+    ck_tile::BlockFmhaBwdConvertQGradPipelineProblem<
+        typename FmhaBwdTypeConfig::AccDataType,
+        typename FmhaBwdTypeConfig::QGradDataType,
+        /* BlockSize = */ 256,
+        64,
+        64,
+        256,
+        true,
+        true,
+        fmha_bwd_convert_dq_trait_0>;
+
+using fmha_bwd_convert_dq_0 =
+    typename ck_tile::BlockFmhaBwdConvertQGrad;
+
+using fmha_bwd_convert_dq_kernel_0 =
+    ck_tile::FmhaBwdConvertQGradKernel;
+
+using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<256,
+    ck_tile::fp16_t,
+    true,
+    true,
+    false,
+    true>;
+
+#include
+
+template <>
+float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a)
+{
+    using k_ = fmha_bwd_convert_dq_kernel_0;
+    if(s.log_level_ > 0)
+        std::cout << ", " << k_::GetName() << std::flush;
+    auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a);
+    constexpr dim3 blocks = k_::BlockSize();
+    constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
+#if (defined(__gfx90a__) || defined(__gfx942__))
+    return ck_tile::launch_kernel(
+        s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs));
+#else
+    return 0.0;
+#endif
+}
+
+template <>
+void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s,
+                                  fmha_bwd_args a)
+{
+    using k_ = fmha_bwd_convert_dq_kernel_0;
+    auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a);
+    constexpr dim3 blocks = k_::BlockSize();
+    constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
+#if (defined(__gfx90a__) || defined(__gfx942__))
+    ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)(
+        ck_tile::stream_config{s.stream_id_});
+#endif
+}
+
+template <>
+std::string fmha_bwd_convert_dq_get_name_()
+{
+    using k_ = fmha_bwd_convert_dq_kernel_0;
+    return k_::GetName();
+}
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6767cce35ab784aa42ebcb75af7305bc38a8721a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6767cce35ab784aa42ebcb75af7305bc38a8721a.hip
new file mode 100644
index 000000000000..81b42c3c796c
--- /dev/null
+++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6767cce35ab784aa42ebcb75af7305bc38a8721a.hip
@@ -0,0 +1,144 @@
+// ==========================================
+// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
+// @generated
+// ==========================================
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6785dcec0197fdbb50124ab06efa627f1a2c0567.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6785dcec0197fdbb50124ab06efa627f1a2c0567.hip new file mode 100644 index 000000000000..6c9bf31e43d7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6785dcec0197fdbb50124ab06efa627f1a2c0567.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_678a4a8210a972bb2ed89d6ac754fb79438ab2da.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_678a4a8210a972bb2ed89d6ac754fb79438ab2da.hip new file mode 100644 index 000000000000..258b8f41131f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_678a4a8210a972bb2ed89d6ac754fb79438ab2da.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_67fb736c61088b8dd92fe0371f5c98e23bf9077f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_67fb736c61088b8dd92fe0371f5c98e23bf9077f.hip new file mode 100644 index 000000000000..b913b516021b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_67fb736c61088b8dd92fe0371f5c98e23bf9077f.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_680e81c3700f130df142c9a37a368944ca548721.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_680e81c3700f130df142c9a37a368944ca548721.hip new file mode 100644 index 000000000000..c0083ea217ae --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_680e81c3700f130df142c9a37a368944ca548721.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_683e8a33fdb7053760c9c135002b0a94facbe015.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_683e8a33fdb7053760c9c135002b0a94facbe015.hip new file mode 100644 index 000000000000..a6b28615b175 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_683e8a33fdb7053760c9c135002b0a94facbe015.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_687f4aaafd1a5b9ee85aadc6fab79ad0c27a2ea2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_687f4aaafd1a5b9ee85aadc6fab79ad0c27a2ea2.hip new file mode 100644 index 000000000000..97f8d43f7d85 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_687f4aaafd1a5b9ee85aadc6fab79ad0c27a2ea2.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = 
ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_688aaa193f332ed13e017e78ec07a7c80e45f6c5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_688aaa193f332ed13e017e78ec07a7c80e45f6c5.hip new file mode 100644 index 000000000000..7203357b70c9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_688aaa193f332ed13e017e78ec07a7c80e45f6c5.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6905ba47078abd7a5b6a51eb93b26095517e7f70.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6905ba47078abd7a5b6a51eb93b26095517e7f70.hip new file mode 100644 index 000000000000..60c5c65b11fa --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6905ba47078abd7a5b6a51eb93b26095517e7f70.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_69214eb450c3b249017480efb8d092b0edad6dc3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_69214eb450c3b249017480efb8d092b0edad6dc3.hip new file mode 100644 index 000000000000..6ec78cd0c6fd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_69214eb450c3b249017480efb8d092b0edad6dc3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6979ef43adffdb62100270a62706fb811963925a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6979ef43adffdb62100270a62706fb811963925a.hip new file mode 100644 index 000000000000..2972744f96e0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6979ef43adffdb62100270a62706fb811963925a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_69cbe8eca7e3510f5caa7f13419cfbefbf031754.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_69cbe8eca7e3510f5caa7f13419cfbefbf031754.hip new file mode 100644 index 000000000000..ee857b9105fa --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_69cbe8eca7e3510f5caa7f13419cfbefbf031754.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6a3f42d5c9ccdd3807e488b00f02bc6ab5d8d99a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6a3f42d5c9ccdd3807e488b00f02bc6ab5d8d99a.hip new file mode 100644 index 000000000000..052d1ed21afb --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6a3f42d5c9ccdd3807e488b00f02bc6ab5d8d99a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6a4b6226b355bf35d4d07aaef1828091f03ad2ec.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6a4b6226b355bf35d4d07aaef1828091f03ad2ec.hip new file mode 100644 index 000000000000..ba0aaab08700 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6a4b6226b355bf35d4d07aaef1828091f03ad2ec.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6a66604bb15f97a56847a7c968dbe32d247cbc13.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6a66604bb15f97a56847a7c968dbe32d247cbc13.hip new file mode 100644 index 000000000000..17b8fd45e9bb --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6a66604bb15f97a56847a7c968dbe32d247cbc13.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; 
+ return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6a7b6781ffff9a42beebb4d73f0d15461ddd4479.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6a7b6781ffff9a42beebb4d73f0d15461ddd4479.hip new file mode 100644 index 000000000000..2fc5358c55fa --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6a7b6781ffff9a42beebb4d73f0d15461ddd4479.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6a7eb3d86aa385f9ecffbc5ba10489e56856f918.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6a7eb3d86aa385f9ecffbc5ba10489e56856f918.hip new file mode 100644 index 000000000000..b303c35613d5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6a7eb3d86aa385f9ecffbc5ba10489e56856f918.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6a95543aeed81adfb6d847f78212585a36122ae3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6a95543aeed81adfb6d847f78212585a36122ae3.hip new file mode 100644 index 000000000000..96a5e0049cb8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6a95543aeed81adfb6d847f78212585a36122ae3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6abeb7b50ae6a1fc62535b9a1dabbde6f177a9d0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6abeb7b50ae6a1fc62535b9a1dabbde6f177a9d0.hip new file mode 100644 index 000000000000..d10457b256f3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6abeb7b50ae6a1fc62535b9a1dabbde6f177a9d0.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6af23d1460abfe875e71f7911697c42fef0f41c5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6af23d1460abfe875e71f7911697c42fef0f41c5.hip new file mode 100644 index 000000000000..8835123a6819 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6af23d1460abfe875e71f7911697c42fef0f41c5.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6af4c15a119e805e4407b184625f57966f8833d9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6af4c15a119e805e4407b184625f57966f8833d9.hip new file mode 100644 index 000000000000..0f2bc69cb2fb --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6af4c15a119e805e4407b184625f57966f8833d9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6b0ef67ce0f178aa2863c4909f5bdd7f766c9b2f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6b0ef67ce0f178aa2863c4909f5bdd7f766c9b2f.hip new file mode 100644 index 000000000000..5506cd001491 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6b0ef67ce0f178aa2863c4909f5bdd7f766c9b2f.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6b638314efcc4f16aa4a6e58e6caf2fda1711519.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6b638314efcc4f16aa4a6e58e6caf2fda1711519.hip new file mode 100644 index 000000000000..65b9b5c79d53 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6b638314efcc4f16aa4a6e58e6caf2fda1711519.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6bad2ed9f91bc1efd89ea66cd5c775fa140cf931.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6bad2ed9f91bc1efd89ea66cd5c775fa140cf931.hip new file mode 100644 index 000000000000..58d654ccc049 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6bad2ed9f91bc1efd89ea66cd5c775fa140cf931.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6cfb7075345704340ff33dc0ef7c04ef127f26ad.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6cfb7075345704340ff33dc0ef7c04ef127f26ad.hip new file mode 100644 index 000000000000..bd3ee7d87a5f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6cfb7075345704340ff33dc0ef7c04ef127f26ad.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 64, + 256, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<256, + ck_tile::bf16_t, + false, + true, + false, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6d07bf9c05e41dcf2416e05dab4bdde17158db76.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6d07bf9c05e41dcf2416e05dab4bdde17158db76.hip new file mode 100644 index 000000000000..5c80eada962f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6d07bf9c05e41dcf2416e05dab4bdde17158db76.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6d17b92fab5bee7717bf9aff6a6bef7cee3816e7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6d17b92fab5bee7717bf9aff6a6bef7cee3816e7.hip new file mode 100644 index 000000000000..35204ee02cd4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6d17b92fab5bee7717bf9aff6a6bef7cee3816e7.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6d307974bdeeef95cca0d130ebb7aeb77fb1b6eb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6d307974bdeeef95cca0d130ebb7aeb77fb1b6eb.hip new file mode 100644 index 000000000000..4c97ae6a1986 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6d307974bdeeef95cca0d130ebb7aeb77fb1b6eb.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6d40d762ed576832b3a752453e9881b5fe6d2650.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6d40d762ed576832b3a752453e9881b5fe6d2650.hip new file mode 100644 index 000000000000..79752788020d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6d40d762ed576832b3a752453e9881b5fe6d2650.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6d470f5c6fb81032fcd7974180297d4bb2a8427d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6d470f5c6fb81032fcd7974180297d4bb2a8427d.hip new file mode 100644 index 000000000000..db765120f69a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6d470f5c6fb81032fcd7974180297d4bb2a8427d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6d5aad18f59e47a3fa3278c7ef1a6372830c33d5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6d5aad18f59e47a3fa3278c7ef1a6372830c33d5.hip new file mode 100644 index 000000000000..bb8407f16c79 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6d5aad18f59e47a3fa3278c7ef1a6372830c33d5.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6db86621d626722434f2ae9b7b8ab435a8dd8827.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6db86621d626722434f2ae9b7b8ab435a8dd8827.hip new file mode 100644 index 000000000000..37f3d2facd37 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6db86621d626722434f2ae9b7b8ab435a8dd8827.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6dd707cf48a17d31abef94215c5720419faa0a39.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6dd707cf48a17d31abef94215c5720419faa0a39.hip new file mode 100644 index 000000000000..28ebbfe667b9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6dd707cf48a17d31abef94215c5720419faa0a39.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6e240106c771ebea461fc2a87b6da68e510aba70.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6e240106c771ebea461fc2a87b6da68e510aba70.hip new file mode 100644 index 000000000000..9856101cf819 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6e240106c771ebea461fc2a87b6da68e510aba70.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); 
+#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6e6a4475ea795935f4cbf2dc0ac156a33d754587.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6e6a4475ea795935f4cbf2dc0ac156a33d754587.hip new file mode 100644 index 000000000000..62f6798d2243 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6e6a4475ea795935f4cbf2dc0ac156a33d754587.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6e7e1d245baabe2f6293e3d85318f9936b333500.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6e7e1d245baabe2f6293e3d85318f9936b333500.hip new file mode 100644 index 000000000000..a8c34655cd73 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6e7e1d245baabe2f6293e3d85318f9936b333500.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6e8cda718e10824956f0ee39bbb0891eafa45a7b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6e8cda718e10824956f0ee39bbb0891eafa45a7b.hip new file mode 100644 index 000000000000..4b2292110fbb --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6e8cda718e10824956f0ee39bbb0891eafa45a7b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6eca9cd905ea8b0454cf9564643894682b08cb97.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6eca9cd905ea8b0454cf9564643894682b08cb97.hip new file mode 100644 index 000000000000..1ea5c1d12a47 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6eca9cd905ea8b0454cf9564643894682b08cb97.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) 
+ return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6eebd0c2fbfc85f938b10535855c388971129a28.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6eebd0c2fbfc85f938b10535855c388971129a28.hip new file mode 100644 index 000000000000..471ba97a9daa --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6eebd0c2fbfc85f938b10535855c388971129a28.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 64, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, false, true>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6ef5803b33d97db72eb8a8528aeb3fc956a938cc.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6ef5803b33d97db72eb8a8528aeb3fc956a938cc.hip new 
file mode 100644 index 000000000000..b6e08fe9c100 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6ef5803b33d97db72eb8a8528aeb3fc956a938cc.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = 
fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6f31b3345893eec8ed1ddf1d8de2512b46ff6187.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6f31b3345893eec8ed1ddf1d8de2512b46ff6187.hip new file mode 100644 index 000000000000..3ecb92ff3a71 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6f31b3345893eec8ed1ddf1d8de2512b46ff6187.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + 
false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6f3d098f8bb63133924aab70d26a6ed64018c13b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6f3d098f8bb63133924aab70d26a6ed64018c13b.hip new file mode 100644 index 000000000000..4fd221d7dca7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6f3d098f8bb63133924aab70d26a6ed64018c13b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6f8788c537cbf6833c58a6ca15c0a36de33c9fbd.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6f8788c537cbf6833c58a6ca15c0a36de33c9fbd.hip new file mode 100644 index 000000000000..28862392bac2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6f8788c537cbf6833c58a6ca15c0a36de33c9fbd.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6f88527a2cdb5adf51407f4661a254bb32d7de23.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6f88527a2cdb5adf51407f4661a254bb32d7de23.hip new file mode 100644 index 000000000000..b7b0813e1bf1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6f88527a2cdb5adf51407f4661a254bb32d7de23.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); 
+#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6fa6478cc27e52fd9511fbff38369c921155cfb9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6fa6478cc27e52fd9511fbff38369c921155cfb9.hip new file mode 100644 index 000000000000..941581b5be80 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6fa6478cc27e52fd9511fbff38369c921155cfb9.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6ff4605d82507fc4bd6e96095eaee5173ea41973.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6ff4605d82507fc4bd6e96095eaee5173ea41973.hip new file mode 100644 index 000000000000..51026d282731 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6ff4605d82507fc4bd6e96095eaee5173ea41973.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6ff58a5186d69efd6062f3717bd315394ea6592b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6ff58a5186d69efd6062f3717bd315394ea6592b.hip new file mode 100644 index 000000000000..25945a70f72f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_6ff58a5186d69efd6062f3717bd315394ea6592b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_703246f1f53a988cf252eff88bdf814bd382d3ac.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_703246f1f53a988cf252eff88bdf814bd382d3ac.hip new file mode 100644 index 000000000000..5a4038d810bf --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_703246f1f53a988cf252eff88bdf814bd382d3ac.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = 
ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_70586668a61ab88bc46b763df8f1c2ea52001ea0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_70586668a61ab88bc46b763df8f1c2ea52001ea0.hip new file mode 100644 index 000000000000..b20806d0ece7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_70586668a61ab88bc46b763df8f1c2ea52001ea0.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_70c8e45f6ea7cf5dba9eeadd0b19481d9f5defb7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_70c8e45f6ea7cf5dba9eeadd0b19481d9f5defb7.hip new file mode 100644 index 000000000000..405ee3c4d1a3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_70c8e45f6ea7cf5dba9eeadd0b19481d9f5defb7.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_70cf755f1485c065222be4daab84283a9c3d0eb7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_70cf755f1485c065222be4daab84283a9c3d0eb7.hip new file mode 100644 index 000000000000..fe11bb581ade --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_70cf755f1485c065222be4daab84283a9c3d0eb7.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_714c5369aa848021e020d874289e3ae4e0f74d77.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_714c5369aa848021e020d874289e3ae4e0f74d77.hip new file mode 100644 index 000000000000..79d5dd08b05d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_714c5369aa848021e020d874289e3ae4e0f74d77.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7177f939ac3dae8749cbf4232dcf04d2cf63b48f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7177f939ac3dae8749cbf4232dcf04d2cf63b48f.hip new file mode 100644 index 000000000000..f18be3af4cf4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7177f939ac3dae8749cbf4232dcf04d2cf63b48f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_71a2d046629a4b65c90d0e18d061c4984062f844.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_71a2d046629a4b65c90d0e18d061c4984062f844.hip new file mode 100644 index 000000000000..d834040de4e6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_71a2d046629a4b65c90d0e18d061c4984062f844.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_71b6100efe30d836dab557ea4ac54c4b9d35c6aa.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_71b6100efe30d836dab557ea4ac54c4b9d35c6aa.hip new file mode 100644 index 000000000000..d98aeafb88ff --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_71b6100efe30d836dab557ea4ac54c4b9d35c6aa.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_71dcbe9f481c92215f3b636bc0e86ce8f65e6472.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_71dcbe9f481c92215f3b636bc0e86ce8f65e6472.hip new file mode 100644 index 000000000000..579a23f60929 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_71dcbe9f481c92215f3b636bc0e86ce8f65e6472.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_71e3980331dc4bcec6ab6f4c345c7b5f71356979.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_71e3980331dc4bcec6ab6f4c345c7b5f71356979.hip new file mode 100644 index 000000000000..c4be9d2eb58f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_71e3980331dc4bcec6ab6f4c345c7b5f71356979.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_71e5fb3544dafa9da03fd2de4bb9bd0718f6009f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_71e5fb3544dafa9da03fd2de4bb9bd0718f6009f.hip new file mode 100644 index 000000000000..a28fce33b7b0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_71e5fb3544dafa9da03fd2de4bb9bd0718f6009f.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 256, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, true>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7237ce5f3cf13ace3efc0b0227ae5a8c1fdfce1d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7237ce5f3cf13ace3efc0b0227ae5a8c1fdfce1d.hip new file mode 100644 index 000000000000..3baa424043e8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7237ce5f3cf13ace3efc0b0227ae5a8c1fdfce1d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_724d1d4408196d611b2e0535bf8833652acbd6ef.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_724d1d4408196d611b2e0535bf8833652acbd6ef.hip new file mode 100644 index 000000000000..85adc2a66a65 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_724d1d4408196d611b2e0535bf8833652acbd6ef.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 32, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<32, + ck_tile::fp16_t, + false, + true, + false, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7264e378e1ea1d4dd97f6949d66f3492883b663e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7264e378e1ea1d4dd97f6949d66f3492883b663e.hip new file mode 100644 index 000000000000..e2bbd10801e3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7264e378e1ea1d4dd97f6949d66f3492883b663e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_72abb25dba0c48b380b2dabeb6ab7efaa706d180.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_72abb25dba0c48b380b2dabeb6ab7efaa706d180.hip new file mode 100644 index 000000000000..832b29ff44e0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_72abb25dba0c48b380b2dabeb6ab7efaa706d180.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7309c38fc8a2d5ad6efd449107dc54a7509624fe.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7309c38fc8a2d5ad6efd449107dc54a7509624fe.hip new file mode 100644 index 000000000000..252c8f7d0dd6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7309c38fc8a2d5ad6efd449107dc54a7509624fe.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7344f96bed2f56793b1c2583485aa161cdf30379.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7344f96bed2f56793b1c2583485aa161cdf30379.hip new file mode 100644 index 000000000000..01c9d16f3a08 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7344f96bed2f56793b1c2583485aa161cdf30379.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7393267865f1c2b0aa1a09a586f54cec98eea4ae.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7393267865f1c2b0aa1a09a586f54cec98eea4ae.hip new file mode 100644 index 000000000000..28017c42ec79 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7393267865f1c2b0aa1a09a586f54cec98eea4ae.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 32, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, false>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_73d4901b8ef034590314048de7223a572d61ee0f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_73d4901b8ef034590314048de7223a572d61ee0f.hip new file mode 100644 index 000000000000..3e13537d3d00 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_73d4901b8ef034590314048de7223a572d61ee0f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_73ec21ed6e040260c4f04ef68ef9307aa86985a7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_73ec21ed6e040260c4f04ef68ef9307aa86985a7.hip new file mode 100644 index 000000000000..5423fe85c9da --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_73ec21ed6e040260c4f04ef68ef9307aa86985a7.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_741401abfbbbdf0dd1d62df8bc3e85371ead71d6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_741401abfbbbdf0dd1d62df8bc3e85371ead71d6.hip new file mode 100644 index 000000000000..f5c10ce4d21f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_741401abfbbbdf0dd1d62df8bc3e85371ead71d6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_743176ecb1f0bc800c870861585edf56f88d7739.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_743176ecb1f0bc800c870861585edf56f88d7739.hip new file mode 100644 index 000000000000..5707e0019825 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_743176ecb1f0bc800c870861585edf56f88d7739.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_744ec604c577a27e0aae5b39711a9e2eb82801b6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_744ec604c577a27e0aae5b39711a9e2eb82801b6.hip new file mode 100644 index 000000000000..2b705c532a1c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_744ec604c577a27e0aae5b39711a9e2eb82801b6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_745705ae121a1a331527cedfe4d31218a428a0df.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_745705ae121a1a331527cedfe4d31218a428a0df.hip new file mode 100644 index 000000000000..b01539a5d08e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_745705ae121a1a331527cedfe4d31218a428a0df.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_748a3d76e8ab73af9a5d2302d33e3b1d1b866dd1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_748a3d76e8ab73af9a5d2302d33e3b1d1b866dd1.hip new file mode 100644 index 000000000000..c16fef433fe0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_748a3d76e8ab73af9a5d2302d33e3b1d1b866dd1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; 
+ return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7497eca4d1a18306b406b367653622a8d64095bf.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7497eca4d1a18306b406b367653622a8d64095bf.hip new file mode 100644 index 000000000000..c50336b6148b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7497eca4d1a18306b406b367653622a8d64095bf.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_74ba59d347ce8916a22b40e6f22a3c89e13db4d0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_74ba59d347ce8916a22b40e6f22a3c89e13db4d0.hip new file mode 100644 index 000000000000..3cbcfca0c779 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_74ba59d347ce8916a22b40e6f22a3c89e13db4d0.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_74d5f2aef029f2103bb419cc982cae99fd1a9253.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_74d5f2aef029f2103bb419cc982cae99fd1a9253.hip new file mode 100644 index 000000000000..9ef9694e78ff --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_74d5f2aef029f2103bb419cc982cae99fd1a9253.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7524904ac5a2040c7ea72aef5942212f291a21bf.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7524904ac5a2040c7ea72aef5942212f291a21bf.hip new file mode 100644 index 000000000000..1e1bdfb1bd6e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7524904ac5a2040c7ea72aef5942212f291a21bf.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
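Each of these autogenerated translation units follows the same structure visible in the hunks above: it pins one tile-shape/trait combination, instantiates the corresponding ck_tile kernel type, and then provides an explicit specialization of the generic launcher template declared in the shared fmha header, with the actual kernel launch guarded so device code is only emitted for gfx90a and gfx942 (otherwise the launcher returns 0.0). The following is only a schematic sketch of that per-trait specialization pattern; every name in it (launcher_, MyTrait, MyKernel, Args, Stream, SUPPORTED_ARCH) is a simplified placeholder for illustration, not the real ck_tile or fmha API.

// Sketch of the per-trait explicit-specialization pattern used by the generated files.
// All identifiers here are illustrative stand-ins, not the actual ck_tile types.
#include <iostream>
#include <string>

struct Args   { int seqlen; };          // stands in for fmha_bwd_args / fmha_fwd_args
struct Stream { int log_level = 0; };   // stands in for ck_tile::stream_config

// Generic launcher: declared once in a shared header, never defined generically.
template <typename Trait>
float launcher_(const Stream& s, Args a);

// One generated .hip file fixes a single configuration...
struct MyTrait {};                      // corresponds to dq_dk_dv_trait_0 / trait_0
struct MyKernel {
    static std::string GetName() { return "my_kernel"; }
    static float Run(const Args& a) { return static_cast<float>(a.seqlen); }
};

// ...and supplies the explicit specialization for exactly that trait, guarded by
// an architecture check that plays the role of __gfx90a__ / __gfx942__ above.
template <>
float launcher_<MyTrait>(const Stream& s, Args a)
{
    if (s.log_level > 0)
        std::cout << ", " << MyKernel::GetName() << std::flush;
#if defined(SUPPORTED_ARCH)
    return MyKernel::Run(a);            // launch the pinned kernel configuration
#else
    (void)a;
    return 0.0f;                        // same fallback as the generated launchers
#endif
}

int main()
{
    Stream s{1};
    Args a{128};
    // A dispatcher in the shared header selects the specialization by trait type.
    std::cout << launcher_<MyTrait>(s, a) << '\n';
    return 0;
}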
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_758b211174da0f398b2a093e7389905b4f9c4060.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_758b211174da0f398b2a093e7389905b4f9c4060.hip new file mode 100644 index 000000000000..ae87a4aa231d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_758b211174da0f398b2a093e7389905b4f9c4060.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7596c14b8fee751d03f42ca48ea4f66e87fc2e2f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7596c14b8fee751d03f42ca48ea4f66e87fc2e2f.hip new file mode 100644 index 000000000000..b146cf408456 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7596c14b8fee751d03f42ca48ea4f66e87fc2e2f.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 
0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7597ce4d2e5264bdeda47487d5bdb55a014c6616.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7597ce4d2e5264bdeda47487d5bdb55a014c6616.hip new file mode 100644 index 000000000000..8b7074e99905 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7597ce4d2e5264bdeda47487d5bdb55a014c6616.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if 
(defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_75a310a6eb86e3e8baac7a930c3ffbef372942b3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_75a310a6eb86e3e8baac7a930c3ffbef372942b3.hip new file mode 100644 index 000000000000..cf727977b736 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_75a310a6eb86e3e8baac7a930c3ffbef372942b3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using 
fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_75c38912947881caa14b3fc7ab7bca317e296dc3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_75c38912947881caa14b3fc7ab7bca317e296dc3.hip new file mode 100644 index 000000000000..1effa7fb7fc2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_75c38912947881caa14b3fc7ab7bca317e296dc3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_75f2010bf6c478d2f0eba77e912697661306c1cb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_75f2010bf6c478d2f0eba77e912697661306c1cb.hip new file mode 100644 index 000000000000..fb3ddc019fc2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_75f2010bf6c478d2f0eba77e912697661306c1cb.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 128, + true, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128, + ck_tile::bf16_t, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_75f21e38ad01fade35b1db40adabd75eb602410c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_75f21e38ad01fade35b1db40adabd75eb602410c.hip new file mode 100644 index 000000000000..ca43e2867ac4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_75f21e38ad01fade35b1db40adabd75eb602410c.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7601e6aea44b96e94fb019501be6b102c6e6a654.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7601e6aea44b96e94fb019501be6b102c6e6a654.hip new file mode 100644 index 000000000000..0f8822144108 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7601e6aea44b96e94fb019501be6b102c6e6a654.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_761bde840c0c8149b24a8f6f264e963c4e9e8ceb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_761bde840c0c8149b24a8f6f264e963c4e9e8ceb.hip new file mode 100644 index 000000000000..181b498ce72c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_761bde840c0c8149b24a8f6f264e963c4e9e8ceb.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + 
s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_765940baaaa2ae6ade43ef4c94a220eaa63702b0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_765940baaaa2ae6ade43ef4c94a220eaa63702b0.hip new file mode 100644 index 000000000000..9bba041444d1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_765940baaaa2ae6ade43ef4c94a220eaa63702b0.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = 
fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_76674fc182dfa6329c73a354aa3adf458429444a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_76674fc182dfa6329c73a354aa3adf458429444a.hip new file mode 100644 index 000000000000..cee13840f3db --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_76674fc182dfa6329c73a354aa3adf458429444a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_76704ca28a4877a1e84022e022614709adabb280.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_76704ca28a4877a1e84022e022614709adabb280.hip new file mode 100644 index 000000000000..b9fa502c84d4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_76704ca28a4877a1e84022e022614709adabb280.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_768c80fd3ea17813df1bf19a158186834fd00780.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_768c80fd3ea17813df1bf19a158186834fd00780.hip new file mode 100644 index 000000000000..93a77a8b9d1f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_768c80fd3ea17813df1bf19a158186834fd00780.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_76be322fc072ca19baa82707e260c6eba936ae19.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_76be322fc072ca19baa82707e260c6eba936ae19.hip new file mode 100644 index 000000000000..b2b778250784 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_76be322fc072ca19baa82707e260c6eba936ae19.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_76f884e9ca116ee47b446efe9fc770c178a858d5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_76f884e9ca116ee47b446efe9fc770c178a858d5.hip new file mode 100644 index 000000000000..dfc3921c02eb --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_76f884e9ca116ee47b446efe9fc770c178a858d5.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_770ad1eb1b30ad8f1e7c17df486093129b2d5630.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_770ad1eb1b30ad8f1e7c17df486093129b2d5630.hip new file mode 100644 index 000000000000..edfcfeb1cc37 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_770ad1eb1b30ad8f1e7c17df486093129b2d5630.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else 
+ return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_77200e875e0ef160b311c7de450c137772312d0d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_77200e875e0ef160b311c7de450c137772312d0d.hip new file mode 100644 index 000000000000..e8423351ee7c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_77200e875e0ef160b311c7de450c137772312d0d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_772016803aa3ca6ebe785557118365f9be7c4339.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_772016803aa3ca6ebe785557118365f9be7c4339.hip new file mode 100644 index 000000000000..5a047976a5f6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_772016803aa3ca6ebe785557118365f9be7c4339.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, 
grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7726be8909f631c04d4395fa4ffd03a736f447f1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7726be8909f631c04d4395fa4ffd03a736f447f1.hip new file mode 100644 index 000000000000..62579cdaf7f0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7726be8909f631c04d4395fa4ffd03a736f447f1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if 
(defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7728d5bec7941c9b6d5632bee8d67ed92b9c03ec.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7728d5bec7941c9b6d5632bee8d67ed92b9c03ec.hip new file mode 100644 index 000000000000..ce79ac4a6338 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7728d5bec7941c9b6d5632bee8d67ed92b9c03ec.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using 
fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7764814a0de7702f0b7b5ce9dede6440603f4853.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7764814a0de7702f0b7b5ce9dede6440603f4853.hip new file mode 100644 index 000000000000..ca718c6ce1f5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7764814a0de7702f0b7b5ce9dede6440603f4853.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_77a814291d8f01870274149b9d82fb75921d6e20.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_77a814291d8f01870274149b9d82fb75921d6e20.hip new file mode 100644 index 000000000000..147cdc179f14 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_77a814291d8f01870274149b9d82fb75921d6e20.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_77d0223697ed41c4c2fd8830f8df6e5620db547f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_77d0223697ed41c4c2fd8830f8df6e5620db547f.hip new file mode 100644 index 000000000000..555847bb1d1a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_77d0223697ed41c4c2fd8830f8df6e5620db547f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; 
+ +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7831ce329f2a0812ebb1dd103ea4ba8cb7ba531d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7831ce329f2a0812ebb1dd103ea4ba8cb7ba531d.hip new file mode 100644 index 000000000000..35207bc71156 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7831ce329f2a0812ebb1dd103ea4ba8cb7ba531d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7838849e57ee9cd292e588f587a8079b57becfc8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7838849e57ee9cd292e588f587a8079b57becfc8.hip new file mode 100644 index 000000000000..b445dda9b5f2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7838849e57ee9cd292e588f587a8079b57becfc8.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_783ec08544591a22f59dc12f169b7327b4185a1a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_783ec08544591a22f59dc12f169b7327b4185a1a.hip new file mode 100644 index 000000000000..0e7b5e2b256d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_783ec08544591a22f59dc12f169b7327b4185a1a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_784c35fee4d372123631312f1051c43e1fa12378.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_784c35fee4d372123631312f1051c43e1fa12378.hip new file mode 100644 index 000000000000..3b96cb2f6da0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_784c35fee4d372123631312f1051c43e1fa12378.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_78663faeb0425f45e8a0da0f7b1a5ddbee5e07e7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_78663faeb0425f45e8a0da0f7b1a5ddbee5e07e7.hip new file mode 100644 index 000000000000..7fa113622d50 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_78663faeb0425f45e8a0da0f7b1a5ddbee5e07e7.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7872c45ba170f2782c4b5b75cfc78ac79a4cf157.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7872c45ba170f2782c4b5b75cfc78ac79a4cf157.hip new file mode 100644 index 000000000000..14b1e581ef1c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7872c45ba170f2782c4b5b75cfc78ac79a4cf157.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; 
+ +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7878e2a4d3b96a552e03d1ffc33debfd50c9f7f1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7878e2a4d3b96a552e03d1ffc33debfd50c9f7f1.hip new file mode 100644 index 000000000000..81f8717ec473 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7878e2a4d3b96a552e03d1ffc33debfd50c9f7f1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_78e1edca5abe1bb3e7aa946eab6484b7bed806a3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_78e1edca5abe1bb3e7aa946eab6484b7bed806a3.hip new file mode 100644 index 000000000000..d847abfc5e27 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_78e1edca5abe1bb3e7aa946eab6484b7bed806a3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_78e945db4afa1330fe3978bc1bc9ae99828ae287.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_78e945db4afa1330fe3978bc1bc9ae99828ae287.hip new file mode 100644 index 000000000000..cfa83bac96b3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_78e945db4afa1330fe3978bc1bc9ae99828ae287.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_78f7e2a2c08cd87702793f91b6935cbe4c22be55.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_78f7e2a2c08cd87702793f91b6935cbe4c22be55.hip new file mode 100644 index 000000000000..4ac9b91617a2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_78f7e2a2c08cd87702793f91b6935cbe4c22be55.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_797750ac0b18b48f56ceb4640256e9bd3a36621a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_797750ac0b18b48f56ceb4640256e9bd3a36621a.hip new file mode 100644 index 000000000000..c01aa886d333 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_797750ac0b18b48f56ceb4640256e9bd3a36621a.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7993fc08ac5c6ce7a2eceb1227f4e3718dc4cf5f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7993fc08ac5c6ce7a2eceb1227f4e3718dc4cf5f.hip new file mode 100644 index 000000000000..36ef130d4014 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7993fc08ac5c6ce7a2eceb1227f4e3718dc4cf5f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_79a7dce707954e765d97cb22e57d9bd6168860d9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_79a7dce707954e765d97cb22e57d9bd6168860d9.hip new file mode 100644 index 000000000000..60a766676023 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_79a7dce707954e765d97cb22e57d9bd6168860d9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + 
constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_79d0b8053ddf99a4d4447656d733c2da026b3a7c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_79d0b8053ddf99a4d4447656d733c2da026b3a7c.hip new file mode 100644 index 000000000000..cd262a1319ba --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_79d0b8053ddf99a4d4447656d733c2da026b3a7c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + 
ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_79f182ae021e23869d7bebf2a9b4575bdc910ed0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_79f182ae021e23869d7bebf2a9b4575bdc910ed0.hip new file mode 100644 index 000000000000..b1bf4ec0f3c8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_79f182ae021e23869d7bebf2a9b4575bdc910ed0.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7a0ab620e6d62259a559e329460e46e6e3f7c3f9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7a0ab620e6d62259a559e329460e46e6e3f7c3f9.hip new file mode 100644 index 000000000000..741db9df0c53 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7a0ab620e6d62259a559e329460e46e6e3f7c3f9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7a13d62a715fd717f0d4101f787349cb49cbe70f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7a13d62a715fd717f0d4101f787349cb49cbe70f.hip new file mode 100644 index 000000000000..603e08729210 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7a13d62a715fd717f0d4101f787349cb49cbe70f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + 
s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7a242e5953f44316b6a4f6587ec26283ed6cbcae.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7a242e5953f44316b6a4f6587ec26283ed6cbcae.hip new file mode 100644 index 000000000000..b93e01af0ce6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7a242e5953f44316b6a4f6587ec26283ed6cbcae.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = 
fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7a2e032f6500fbc5468183415b6dd1d3e43f0bee.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7a2e032f6500fbc5468183415b6dd1d3e43f0bee.hip new file mode 100644 index 000000000000..e693144bf20d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7a2e032f6500fbc5468183415b6dd1d3e43f0bee.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7a890b126da2d8cfbf84f048b779cac2dd56b509.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7a890b126da2d8cfbf84f048b779cac2dd56b509.hip new file mode 100644 index 000000000000..c3275fe13362 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7a890b126da2d8cfbf84f048b779cac2dd56b509.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7a902ed4ae3cc6558c73b730ff3949778007a230.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7a902ed4ae3cc6558c73b730ff3949778007a230.hip new file mode 100644 index 000000000000..56793e3969e9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7a902ed4ae3cc6558c73b730ff3949778007a230.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7aa14aa94d625b33df1adfa30ef4d91769592608.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7aa14aa94d625b33df1adfa30ef4d91769592608.hip new file mode 100644 index 000000000000..200e27910675 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7aa14aa94d625b33df1adfa30ef4d91769592608.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 32, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<32, + ck_tile::fp16_t, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7ab03a62e064864e1e9c1cd506c1b2e1786a777c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7ab03a62e064864e1e9c1cd506c1b2e1786a777c.hip new file mode 100644 index 000000000000..ad200b3bc3be --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7ab03a62e064864e1e9c1cd506c1b2e1786a777c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7adf69b51f0a8cc9ae7e250e60df38758230fe4f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7adf69b51f0a8cc9ae7e250e60df38758230fe4f.hip new file mode 100644 index 000000000000..2022023dd507 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7adf69b51f0a8cc9ae7e250e60df38758230fe4f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( 
+ s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7afd1a756247b15b078d15a39e350a07c22982da.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7afd1a756247b15b078d15a39e350a07c22982da.hip new file mode 100644 index 000000000000..9738d2e4d62c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7afd1a756247b15b078d15a39e350a07c22982da.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = 
fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7b2d3680c3578c7292349b58843aef7a82e0087d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7b2d3680c3578c7292349b58843aef7a82e0087d.hip new file mode 100644 index 000000000000..4a409a802f6d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7b2d3680c3578c7292349b58843aef7a82e0087d.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7b5680f97836be4a369802e8115617a83875703e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7b5680f97836be4a369802e8115617a83875703e.hip new file mode 100644 index 000000000000..affd984987db --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7b5680f97836be4a369802e8115617a83875703e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7b67045d438a7e4b8f3a313a5df5a85f351c1be5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7b67045d438a7e4b8f3a313a5df5a85f351c1be5.hip new file mode 100644 index 000000000000..ec8213154f61 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7b67045d438a7e4b8f3a313a5df5a85f351c1be5.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 64, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<64, + ck_tile::fp16_t, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7b7fa76609243a8709f349ffc0d9d88157f28dc9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7b7fa76609243a8709f349ffc0d9d88157f28dc9.hip new file mode 100644 index 000000000000..b667ec694e6f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7b7fa76609243a8709f349ffc0d9d88157f28dc9.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
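+
+// Units like this one instantiate the standalone dQ-conversion step of the backward
+// pass: fmha_bwd_convert_dq_ casts the fp32 dQ accumulator (the AccDataType of
+// FmhaBwdTypeConfig) into the final fp16 QGradDataType, and is guarded to
+// gfx90a/gfx942 in the same way as the other generated kernels.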
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 64, + 256, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<256, + ck_tile::fp16_t, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7b9a3bf1a9b37e0bd9bae6249609e5994dc0dba1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7b9a3bf1a9b37e0bd9bae6249609e5994dc0dba1.hip new file mode 100644 index 000000000000..c13bfb9dd925 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7b9a3bf1a9b37e0bd9bae6249609e5994dc0dba1.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
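+
+// Forward-direction units such as this one are simpler than the backward files: one
+// TileFmhaShape/TileFmhaTraits configuration, one pipeline
+// (BlockFmhaPipelineQRKSVSAsync here), and a single float-returning fmha_fwd_
+// launcher, with no oneshot_ or get_name_ companions.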
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7bb7b63e8a4c1df4eac4d978e166867195bd6e53.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7bb7b63e8a4c1df4eac4d978e166867195bd6e53.hip new file mode 100644 index 000000000000..73fbe5ebab1f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7bb7b63e8a4c1df4eac4d978e166867195bd6e53.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
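+
+// Each generated unit pins its kernel to a *_traits_ alias (dq_dk_dv_trait_0 below)
+// and then provides what appear to be explicit template specializations keyed on that
+// trait, presumably so the non-generated dispatch code can pick the matching
+// specialization at runtime from the head dim, dtype, mask, bias and dropout settings
+// it receives.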
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7c19fc90e5a9c422dbf529d2def286f47dea0f50.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7c19fc90e5a9c422dbf529d2def286f47dea0f50.hip new file mode 100644 index 000000000000..80f364bf21bd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7c19fc90e5a9c422dbf529d2def286f47dea0f50.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + 
s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7c23dde1a386436e9864c8fa5f1706c0d2fbfd0d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7c23dde1a386436e9864c8fa5f1706c0d2fbfd0d.hip new file mode 100644 index 000000000000..9184e7f08e11 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7c23dde1a386436e9864c8fa5f1706c0d2fbfd0d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = 
fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7c3d8ef4da515960bf40eb1feb04d21950ad5ae5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7c3d8ef4da515960bf40eb1feb04d21950ad5ae5.hip new file mode 100644 index 000000000000..222afbb4e77b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7c3d8ef4da515960bf40eb1feb04d21950ad5ae5.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
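+
+// Like the other backward units, this file exports three entry points for its trait:
+// a float-returning launcher that goes through ck_tile::launch_kernel (and simply
+// returns 0.0 when not compiled for gfx90a/gfx942), a oneshot_ variant that invokes
+// the kernel directly with a stream_config built from the caller's stream id, and a
+// get_name_ helper that returns k_::GetName() for identification.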
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7c4710e8f4e27fae4ae079f1667c3a1879cb6da8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7c4710e8f4e27fae4ae079f1667c3a1879cb6da8.hip new file mode 100644 index 000000000000..5526e7868aeb --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7c4710e8f4e27fae4ae079f1667c3a1879cb6da8.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7cbe4562c51d6829ec5942e11035c452fe318b3a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7cbe4562c51d6829ec5942e11035c452fe318b3a.hip new file mode 100644 index 000000000000..0d8237272563 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7cbe4562c51d6829ec5942e11035c452fe318b3a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7cdc419d4248dfdeeab1f0980aec35fa134e52e0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7cdc419d4248dfdeeab1f0980aec35fa134e52e0.hip new file mode 100644 index 000000000000..89319da79dc7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7cdc419d4248dfdeeab1f0980aec35fa134e52e0.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7d08373ace7087bdaca4ce8b0bc329f553f88d77.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7d08373ace7087bdaca4ce8b0bc329f553f88d77.hip new file mode 100644 index 000000000000..a2c9587f665e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7d08373ace7087bdaca4ce8b0bc329f553f88d77.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( 
+ s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7d0f767c17385eb7d756cbe8ed444d7cef72dea5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7d0f767c17385eb7d756cbe8ed444d7cef72dea5.hip new file mode 100644 index 000000000000..721a61bf546f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7d0f767c17385eb7d756cbe8ed444d7cef72dea5.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 256, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, false, false, false>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7d12e9cb599d24631c082e3cf65d2c58b6d4d44f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7d12e9cb599d24631c082e3cf65d2c58b6d4d44f.hip new file mode 100644 index 
000000000000..56d732bc792d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7d12e9cb599d24631c082e3cf65d2c58b6d4d44f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7d2f87c021e0b6a27b2d7e30351fd50f06414b5f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7d2f87c021e0b6a27b2d7e30351fd50f06414b5f.hip new file mode 100644 index 000000000000..81f03941a6a2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7d2f87c021e0b6a27b2d7e30351fd50f06414b5f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + 
+template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7d5667b27f15a06d4040354fba3601d48bb9c045.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7d5667b27f15a06d4040354fba3601d48bb9c045.hip new file mode 100644 index 000000000000..041654a04733 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7d5667b27f15a06d4040354fba3601d48bb9c045.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
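+
+// Unlike the *_ASYNC forward units earlier in this patch, the kernel below uses the
+// plain BlockFmhaPipelineQRKSVS pipeline, and its trait_0 selects
+// BlockAttentionBiasEnum::ALIBI, making this one of the bias-enabled forward
+// configurations.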
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7dac5d4cf103d658e129673549549f1276f134e0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7dac5d4cf103d658e129673549549f1276f134e0.hip new file mode 100644 index 000000000000..977dd8080e9b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7dac5d4cf103d658e129673549549f1276f134e0.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7dd260849b86c46b685955cab54ba07d49b47954.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7dd260849b86c46b685955cab54ba07d49b47954.hip new file mode 100644 index 000000000000..05280ff531bf --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7dd260849b86c46b685955cab54ba07d49b47954.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7ddd621da88c57798db1e689b93b692b6519ff96.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7ddd621da88c57798db1e689b93b692b6519ff96.hip new file mode 100644 index 000000000000..7e2415b296c3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7ddd621da88c57798db1e689b93b692b6519ff96.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7dfe21ee27f8a0ca0407ef0dea73cd73ae6940db.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7dfe21ee27f8a0ca0407ef0dea73cd73ae6940db.hip new file mode 100644 index 000000000000..72b7af002ff1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7dfe21ee27f8a0ca0407ef0dea73cd73ae6940db.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
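// Each generated launcher body is wrapped in
// #if (defined(__gfx90a__) || defined(__gfx942__)): on those architectures it
// performs the real ck_tile::make_kernel / launch_kernel call, while on any
// other target the function degenerates to a stub (returning 0.0 or doing
// nothing), so these translation units still compile even where the CK
// flash-attention path cannot execute.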
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7e1bdde812c332c9fc58613698568a04771b9fa8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7e1bdde812c332c9fc58613698568a04771b9fa8.hip new file mode 100644 index 000000000000..7ba6eedf761d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7e1bdde812c332c9fc58613698568a04771b9fa8.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7e332a6aeecfb12dcf70c69157fd3137343fb9f6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7e332a6aeecfb12dcf70c69157fd3137343fb9f6.hip new file mode 100644 index 000000000000..9c6f0c8905c3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7e332a6aeecfb12dcf70c69157fd3137343fb9f6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7e6129eead18d13a4a6cb9550384fddabc7a2a16.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7e6129eead18d13a4a6cb9550384fddabc7a2a16.hip new file mode 100644 index 000000000000..68838c8d32ee --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7e6129eead18d13a4a6cb9550384fddabc7a2a16.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7e89f79217037e361bb0909d06534e40f5026b4f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7e89f79217037e361bb0909d06534e40f5026b4f.hip new file mode 100644 index 000000000000..93f3475e6e40 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7e89f79217037e361bb0909d06534e40f5026b4f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7e9519dd0d0f940fd5efd61bd32df7528ba7e3fc.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7e9519dd0d0f940fd5efd61bd32df7528ba7e3fc.hip new file mode 100644 index 000000000000..31c94c7f898b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7e9519dd0d0f940fd5efd61bd32df7528ba7e3fc.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
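// The Gemm0~4 block-warp aliases consumed by TileFmhaBwdShape correspond to the
// five GEMMs of the attention backward pass; per the generator's TODO note
// below, G0 and G2 form GSdP (the S and dP GEMMs), G1 and G3 form GdKV (the dK
// and dV GEMMs), and G4 forms GdQ (the dQ GEMM).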
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7e9c7feb747241c9c7de2adf3a19933a1c4c0995.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7e9c7feb747241c9c7de2adf3a19933a1c4c0995.hip new file mode 100644 index 000000000000..ef7c125928f5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7e9c7feb747241c9c7de2adf3a19933a1c4c0995.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( 
+ s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7ea9c37d92e344f3cc58cd4d1d00f19167e3623e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7ea9c37d92e344f3cc58cd4d1d00f19167e3623e.hip new file mode 100644 index 000000000000..f8a361ce637b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7ea9c37d92e344f3cc58cd4d1d00f19167e3623e.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; 
+#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7ec038393ec329a894aee9bbac078a40f57a4684.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7ec038393ec329a894aee9bbac078a40f57a4684.hip new file mode 100644 index 000000000000..886343649f5b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7ec038393ec329a894aee9bbac078a40f57a4684.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7ec04763d635c5bc3e810737b5d948c59f117d5a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7ec04763d635c5bc3e810737b5d948c59f117d5a.hip new file mode 100644 index 000000000000..32f2eb1c249f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7ec04763d635c5bc3e810737b5d948c59f117d5a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7ee953cb24e28bcdc8f05783894b23cbf83bdf35.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7ee953cb24e28bcdc8f05783894b23cbf83bdf35.hip new file mode 100644 index 000000000000..98bc5dde59e9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7ee953cb24e28bcdc8f05783894b23cbf83bdf35.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
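// Three entry points are specialized per variant: fmha_bwd_dq_dk_dv_ logs the
// kernel name when s.log_level_ > 0 and returns whatever ck_tile::launch_kernel
// reports (presumably the measured launch time), the _oneshot_ variant
// constructs and invokes the kernel once on the given stream without logging or
// timing, and _get_name_ exposes k_::GetName() for diagnostics. A minimal
// host-side sketch, assuming the specializations are keyed on the
// dq_dk_dv_trait_0 alias defined in each file (hypothetical variable names, not
// part of the generated code):
//
//   fmha_bwd_args args{};                      // filled in by the dispatcher
//   ck_tile::stream_config scfg{hip_stream};   // stream id, as used below
//   float t = fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(scfg, args);
//   // or, fire-and-forget:
//   // fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(scfg, args);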
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7f6ccdb3c2d595fffd05bc5e6417b157276547fb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7f6ccdb3c2d595fffd05bc5e6417b157276547fb.hip new file mode 100644 index 000000000000..255a436c7d87 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7f6ccdb3c2d595fffd05bc5e6417b157276547fb.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7f80d44e82e601dc48d4c8b4e710ef7265894b6c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7f80d44e82e601dc48d4c8b4e710ef7265894b6c.hip new file mode 100644 index 000000000000..97b29f4190da --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7f80d44e82e601dc48d4c8b4e710ef7265894b6c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
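// The kernel alias above ties the pipeline and the dK/dV epilogues defined
// earlier in this file into the launchable FmhaBwdDQDKDVKernel type; the
// dq_dk_dv_trait_0 alias that follows encodes this variant's head dim, data
// type, pipeline enum, mask, dropout, bias and padding flags, and is presumably
// the tag on which the explicit specializations further down are keyed.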
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7f9403cb91d6aabebf081afae94a8ba397d8d24f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7f9403cb91d6aabebf081afae94a8ba397d8d24f.hip new file mode 100644 index 000000000000..869517b0f736 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7f9403cb91d6aabebf081afae94a8ba397d8d24f.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 64, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<64, + ck_tile::bf16_t, + false, + false, + true, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7f9bb3486fee7b7c9e24300b8a4e4ce88a11bfc0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7f9bb3486fee7b7c9e24300b8a4e4ce88a11bfc0.hip new file mode 100644 index 000000000000..d4b65bf16020 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7f9bb3486fee7b7c9e24300b8a4e4ce88a11bfc0.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7fa76fc1b066a15b08dc6c24a7cf33a58b4cb6cb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7fa76fc1b066a15b08dc6c24a7cf33a58b4cb6cb.hip new file mode 100644 index 000000000000..60a12bd2f99e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7fa76fc1b066a15b08dc6c24a7cf33a58b4cb6cb.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7fe409f4421193fb48a54aa5f26bd6229d23204c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7fe409f4421193fb48a54aa5f26bd6229d23204c.hip new file mode 100644 index 000000000000..e493c2e5207c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7fe409f4421193fb48a54aa5f26bd6229d23204c.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7ff65c7abd9b0d8a2df9302d6dc167637b3a72f0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7ff65c7abd9b0d8a2df9302d6dc167637b3a72f0.hip new file mode 100644 index 000000000000..a171deece2ae --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_7ff65c7abd9b0d8a2df9302d6dc167637b3a72f0.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8004763f674dfb3f14b66dfdeb2a046e413ce2cb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8004763f674dfb3f14b66dfdeb2a046e413ce2cb.hip new file mode 100644 index 000000000000..3ce9605b6ed2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8004763f674dfb3f14b66dfdeb2a046e413ce2cb.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
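Each generated translation unit above follows the same pattern: it fixes one tile configuration at compile time, defines a trait tag for it, and explicitly specializes the shared launcher (and, for the backward kernels, the oneshot and get_name entry points) for that tag, guarding the actual launch behind the gfx90a/gfx942 target check and returning 0.0 elsewhere. The stand-alone sketch below illustrates only that specialization pattern; every name in it (stream_config, fwd_args, fmha_fwd_, the example trait tag) is a simplified stand-in rather than the real ck_tile/FMHA API, and the real files additionally build kernel arguments and launch through ck_tile::launch_kernel / ck_tile::make_kernel as shown in the hunks above.

// Toy illustration of the "one explicit specialization per trait tag" pattern.
// All types and names here are stand-ins, not the ck_tile API.
#include <iostream>
#include <string>

struct stream_config { int log_level_ = 0; };   // stand-in for ck_tile::stream_config
struct fwd_args { int seqlen_q = 0; };          // stand-in for fmha_fwd_args

// Declared once in a shared header; each generated file provides one specialization.
template <typename Trait> float fmha_fwd_(const stream_config& s, fwd_args a);
template <typename Trait> std::string fmha_fwd_get_name_();

// Each generated file derives one such tag from its tile shape, dtype, mask, bias, etc.
struct trait_hdim64_fp16_example {};

template <>
std::string fmha_fwd_get_name_<trait_hdim64_fp16_example>()
{
    return "fmha_fwd_hdim64_fp16_example";
}

template <>
float fmha_fwd_<trait_hdim64_fp16_example>(const stream_config& s, fwd_args /*a*/)
{
    if (s.log_level_ > 0)
        std::cout << ", " << fmha_fwd_get_name_<trait_hdim64_fp16_example>() << std::flush;
#if defined(__gfx90a__) || defined(__gfx942__)
    // The real files create kargs/grids and launch the ck_tile kernel here,
    // returning the measured time.
    return 1.0f;
#else
    return 0.0f;   // architectures other than gfx90a/gfx942 get a stub
#endif
}

int main()
{
    stream_config s{1};
    std::cout << "\n" << fmha_fwd_<trait_hdim64_fp16_example>(s, fwd_args{}) << std::endl;
    return 0;
}

As the trait template parameters in the surrounding files suggest, the runtime dispatcher selects among these specializations by head dimension, data type, pipeline variant, and mask/bias/dropout/padding options, which is why the patch adds so many near-identical generated files.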
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8007bf7ae1b71bf8ac4a793aa519ad333aa7a7ba.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8007bf7ae1b71bf8ac4a793aa519ad333aa7a7ba.hip new file mode 100644 index 000000000000..2d6b7f56d244 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8007bf7ae1b71bf8ac4a793aa519ad333aa7a7ba.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8021fa266c77e6b5bd1af2a9c22c686e5a6eac78.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8021fa266c77e6b5bd1af2a9c22c686e5a6eac78.hip new file mode 100644 index 000000000000..6ee5fecbdfb0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8021fa266c77e6b5bd1af2a9c22c686e5a6eac78.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_802b21f9588d72c3c3e3b9a3b269f19c484d5aa4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_802b21f9588d72c3c3e3b9a3b269f19c484d5aa4.hip new file mode 100644 index 000000000000..740a14b38963 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_802b21f9588d72c3c3e3b9a3b269f19c484d5aa4.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8046f566fa7188c92568b277354e8b06ad382544.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8046f566fa7188c92568b277354e8b06ad382544.hip new file mode 100644 index 000000000000..e8f7155fdb0a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8046f566fa7188c92568b277354e8b06ad382544.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_806f9ab9baf631df1d3a8d801e4cf93a102526cf.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_806f9ab9baf631df1d3a8d801e4cf93a102526cf.hip new file mode 100644 index 000000000000..668edd0effc5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_806f9ab9baf631df1d3a8d801e4cf93a102526cf.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_807545400aa6e70ff49a5f38ed6a218a180bd87f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_807545400aa6e70ff49a5f38ed6a218a180bd87f.hip new file mode 100644 index 000000000000..a0ae88c8061f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_807545400aa6e70ff49a5f38ed6a218a180bd87f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_80987e2d765efc320eaee813607c94c80ee35aa4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_80987e2d765efc320eaee813607c94c80ee35aa4.hip new file mode 100644 index 000000000000..742463346af5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_80987e2d765efc320eaee813607c94c80ee35aa4.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( 
+ s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_80a72d70d80b66c19e85daa00497308381050048.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_80a72d70d80b66c19e85daa00497308381050048.hip new file mode 100644 index 000000000000..df68604c46fd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_80a72d70d80b66c19e85daa00497308381050048.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 
= fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_80bfb0e6032892cc58cef4dd403f305a5b76851b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_80bfb0e6032892cc58cef4dd403f305a5b76851b.hip new file mode 100644 index 000000000000..2c57edb1b307 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_80bfb0e6032892cc58cef4dd403f305a5b76851b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_80cf0997573f4bcfbaaf75e40f519580a7495a17.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_80cf0997573f4bcfbaaf75e40f519580a7495a17.hip new file mode 100644 index 000000000000..5f901c56ff4c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_80cf0997573f4bcfbaaf75e40f519580a7495a17.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_80efc341089a50ed5669b3c86f6ddd9b124d1442.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_80efc341089a50ed5669b3c86f6ddd9b124d1442.hip new file mode 100644 index 000000000000..44e9c2e4c6cd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_80efc341089a50ed5669b3c86f6ddd9b124d1442.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_80f51f0e178c33e6196df1d2e47bd38bf5391cc8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_80f51f0e178c33e6196df1d2e47bd38bf5391cc8.hip new file mode 100644 index 000000000000..3f856fea1f3e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_80f51f0e178c33e6196df1d2e47bd38bf5391cc8.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_80fb694fce7b4c3c459fca43c89c6002fbfdaef5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_80fb694fce7b4c3c459fca43c89c6002fbfdaef5.hip new file mode 100644 index 000000000000..bd032e27668e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_80fb694fce7b4c3c459fca43c89c6002fbfdaef5.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS 
AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, 
kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_810dd4e870ceda3ba9b5f0084a4b025b2e609d57.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_810dd4e870ceda3ba9b5f0084a4b025b2e609d57.hip new file mode 100644 index 000000000000..ef304c578a79 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_810dd4e870ceda3ba9b5f0084a4b025b2e609d57.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_811db756577b61cde9fe8279d956980db9ee21a4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_811db756577b61cde9fe8279d956980db9ee21a4.hip new file mode 100644 index 000000000000..564d8374617b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_811db756577b61cde9fe8279d956980db9ee21a4.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 128, + true, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128, + ck_tile::bf16_t, + true, + true, + false, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_813e60e8405aca3f7fbed19452ae37574ada9a77.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_813e60e8405aca3f7fbed19452ae37574ada9a77.hip new file mode 100644 index 000000000000..e5d2004adf29 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_813e60e8405aca3f7fbed19452ae37574ada9a77.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_815918206483d2ae04a45aa67d69dfb986587214.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_815918206483d2ae04a45aa67d69dfb986587214.hip new file mode 100644 index 000000000000..17be14af25b7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_815918206483d2ae04a45aa67d69dfb986587214.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_816c48e129a0235cb3a19124ddb28cce286fb368.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_816c48e129a0235cb3a19124ddb28cce286fb368.hip new file mode 100644 index 000000000000..b90714501b09 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_816c48e129a0235cb3a19124ddb28cce286fb368.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_81acf1d17650712b71a499bb66909bfcfcb6aecb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_81acf1d17650712b71a499bb66909bfcfcb6aecb.hip new file mode 100644 index 000000000000..06a07db689ac --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_81acf1d17650712b71a499bb66909bfcfcb6aecb.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( 
+ ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_81bb8f13b6f20a72c9ce6d0b53f81eddbf05f1c6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_81bb8f13b6f20a72c9ce6d0b53f81eddbf05f1c6.hip new file mode 100644 index 000000000000..b5b601d7d53d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_81bb8f13b6f20a72c9ce6d0b53f81eddbf05f1c6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_81dd3ea61bb61de02667b14f5a94198f48c7307b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_81dd3ea61bb61de02667b14f5a94198f48c7307b.hip new file mode 100644 index 000000000000..1ea30c92f144 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_81dd3ea61bb61de02667b14f5a94198f48c7307b.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 64, + true, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<64, + ck_tile::fp16_t, + true, + true, + false, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_81f6c575c3fa2ccc7e65022f1ba65c8cfc16541e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_81f6c575c3fa2ccc7e65022f1ba65c8cfc16541e.hip new file mode 100644 index 000000000000..c348e217f75d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_81f6c575c3fa2ccc7e65022f1ba65c8cfc16541e.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_82048cf91270631f98ac37dc488a1fb2e00ce004.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_82048cf91270631f98ac37dc488a1fb2e00ce004.hip new file mode 100644 index 000000000000..ef460fe8a8c3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_82048cf91270631f98ac37dc488a1fb2e00ce004.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8250f27341241086515d833aa53ae873d4ece3fa.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8250f27341241086515d833aa53ae873d4ece3fa.hip new file mode 100644 index 000000000000..611502df7311 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8250f27341241086515d833aa53ae873d4ece3fa.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 256, + true, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, false>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8278845045d68027dcf3bf867ecde2fb12ec51d3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8278845045d68027dcf3bf867ecde2fb12ec51d3.hip new file mode 100644 index 000000000000..d95a9874bd83 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8278845045d68027dcf3bf867ecde2fb12ec51d3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_82ad0c0580516485ea432d98f53e73f6dfec548c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_82ad0c0580516485ea432d98f53e73f6dfec548c.hip new file mode 100644 index 000000000000..48f4d0aed933 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_82ad0c0580516485ea432d98f53e73f6dfec548c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_82c932e6eaaf44861c794539d9caf8b50192fc44.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_82c932e6eaaf44861c794539d9caf8b50192fc44.hip new file mode 100644 index 000000000000..de5126d11e17 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_82c932e6eaaf44861c794539d9caf8b50192fc44.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_82d7f61e6313930f063758b61102e7a43b118beb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_82d7f61e6313930f063758b61102e7a43b118beb.hip new file mode 100644 index 000000000000..23af435c18b9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_82d7f61e6313930f063758b61102e7a43b118beb.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_82f0f3d71108dcc49234a258f0f3b21ea2123cc0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_82f0f3d71108dcc49234a258f0f3b21ea2123cc0.hip new file mode 100644 index 000000000000..d37bf9ec4a14 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_82f0f3d71108dcc49234a258f0f3b21ea2123cc0.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_82f1d7e1a93bf2fa80c409e6827ea88af56c44f0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_82f1d7e1a93bf2fa80c409e6827ea88af56c44f0.hip new file mode 100644 index 000000000000..cb1531a3738c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_82f1d7e1a93bf2fa80c409e6827ea88af56c44f0.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8301bfc0394936a68fa0098580f06e77c88ebed9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8301bfc0394936a68fa0098580f06e77c88ebed9.hip new file mode 100644 index 000000000000..2f3395e4f404 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8301bfc0394936a68fa0098580f06e77c88ebed9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; 
+ +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_83080406598df6bd3102db70a554e496e29db96a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_83080406598df6bd3102db70a554e496e29db96a.hip new file mode 100644 index 000000000000..3b17343a9684 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_83080406598df6bd3102db70a554e496e29db96a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_830e3532f27b391585d5de90f3bdf97992b67651.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_830e3532f27b391585d5de90f3bdf97992b67651.hip new file mode 100644 index 000000000000..d132dabefd03 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_830e3532f27b391585d5de90f3bdf97992b67651.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8352031044ef2e4a22e27ad04ab5d2c02121faee.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8352031044ef2e4a22e27ad04ab5d2c02121faee.hip new file mode 100644 index 000000000000..7e5865723fca --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8352031044ef2e4a22e27ad04ab5d2c02121faee.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; 
+ return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_835a906031a258c6362313eec783678bd8125c91.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_835a906031a258c6362313eec783678bd8125c91.hip new file mode 100644 index 000000000000..86c405bdf4e3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_835a906031a258c6362313eec783678bd8125c91.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_836a308c2d2afd6e0dfbfda61984b631c4ccffc6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_836a308c2d2afd6e0dfbfda61984b631c4ccffc6.hip new file mode 100644 index 000000000000..d5d7249e38f9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_836a308c2d2afd6e0dfbfda61984b631c4ccffc6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_83d580a612af85533c87aecdd7b0345c71b75980.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_83d580a612af85533c87aecdd7b0345c71b75980.hip new file mode 100644 index 000000000000..cca0753895fc --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_83d580a612af85533c87aecdd7b0345c71b75980.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_83d920a76114c63156740ba5dd6f3846c4b21c28.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_83d920a76114c63156740ba5dd6f3846c4b21c28.hip new file mode 100644 index 000000000000..63cc76c2bc77 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_83d920a76114c63156740ba5dd6f3846c4b21c28.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_83ddca2c6ecbba4314c434e7471ffb8fa642f936.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_83ddca2c6ecbba4314c434e7471ffb8fa642f936.hip new file mode 100644 index 000000000000..69d86ffa6cf9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_83ddca2c6ecbba4314c434e7471ffb8fa642f936.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_83f6a1837a65df12b7c55d25ca28cc939c2a6328.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_83f6a1837a65df12b7c55d25ca28cc939c2a6328.hip new file mode 100644 index 000000000000..b53e8a7fea8b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_83f6a1837a65df12b7c55d25ca28cc939c2a6328.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_843e7888cba5f463d19fcb71aaaab25dc3d2c09d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_843e7888cba5f463d19fcb71aaaab25dc3d2c09d.hip new file mode 100644 index 000000000000..4311527376ac --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_843e7888cba5f463d19fcb71aaaab25dc3d2c09d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, 
ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8441910c34830ad2459fb85c2c14af02da718fdc.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8441910c34830ad2459fb85c2c14af02da718fdc.hip new file mode 100644 index 000000000000..49df5646dcc8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8441910c34830ad2459fb85c2c14af02da718fdc.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = 
fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8457ea5726149efb8778e6d90798b8e48288fc9a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8457ea5726149efb8778e6d90798b8e48288fc9a.hip new file mode 100644 index 000000000000..1f44e17f5f6e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8457ea5726149efb8778e6d90798b8e48288fc9a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_847feaf237911478173377a501ee19ee325b012b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_847feaf237911478173377a501ee19ee325b012b.hip new file mode 100644 index 000000000000..10e961bddae5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_847feaf237911478173377a501ee19ee325b012b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_84cca7528c7d1bf49ba79625733ff0ae7522c096.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_84cca7528c7d1bf49ba79625733ff0ae7522c096.hip new file mode 100644 index 000000000000..e68e0004da8e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_84cca7528c7d1bf49ba79625733ff0ae7522c096.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_84dc4af43de08130a04bfa06df9799b6e9e96900.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_84dc4af43de08130a04bfa06df9799b6e9e96900.hip new file mode 100644 index 000000000000..2631aea8563d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_84dc4af43de08130a04bfa06df9799b6e9e96900.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_84e8ae99e184013739019c93d07caddce532382b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_84e8ae99e184013739019c93d07caddce532382b.hip new file mode 100644 index 000000000000..81a62ac01ad7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_84e8ae99e184013739019c93d07caddce532382b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_84fc5e94f89d6a9287cf64662a372784511468dd.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_84fc5e94f89d6a9287cf64662a372784511468dd.hip new file mode 100644 index 000000000000..f9b3c6036394 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_84fc5e94f89d6a9287cf64662a372784511468dd.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); 
+#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8513d96a66a4d9fb8dfc84afba7e1d8c200248a6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8513d96a66a4d9fb8dfc84afba7e1d8c200248a6.hip new file mode 100644 index 000000000000..f626aec3121c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8513d96a66a4d9fb8dfc84afba7e1d8c200248a6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_85156f2c556c6ef6180608c361b7b35ede71ffea.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_85156f2c556c6ef6180608c361b7b35ede71ffea.hip new file mode 100644 index 000000000000..747b26b3171f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_85156f2c556c6ef6180608c361b7b35ede71ffea.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, 
grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_854c8003a508ed3f8cbe6967c4ae2635a491c721.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_854c8003a508ed3f8cbe6967c4ae2635a491c721.hip new file mode 100644 index 000000000000..56d7d4ab9bfe --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_854c8003a508ed3f8cbe6967c4ae2635a491c721.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_85908fe6dc9c629c82d6953081b10021e64583b1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_85908fe6dc9c629c82d6953081b10021e64583b1.hip new file mode 100644 index 000000000000..e3c6e6402752 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_85908fe6dc9c629c82d6953081b10021e64583b1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_85960fe542635079de5eca3c7785890cd4740005.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_85960fe542635079de5eca3c7785890cd4740005.hip new file mode 100644 index 000000000000..b0136c222605 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_85960fe542635079de5eca3c7785890cd4740005.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_85fdde4b25e2fc8cbdd46c2850c19eac8d9af8f6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_85fdde4b25e2fc8cbdd46c2850c19eac8d9af8f6.hip new file mode 100644 index 000000000000..4d88eeb5a8f8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_85fdde4b25e2fc8cbdd46c2850c19eac8d9af8f6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = 
ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_86309c036d96367939ccc3e8922595ac35a3e179.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_86309c036d96367939ccc3e8922595ac35a3e179.hip new file mode 100644 index 000000000000..d500d2e7da3b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_86309c036d96367939ccc3e8922595ac35a3e179.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
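Each of the autogenerated translation units above follows the same pattern: a handful of type aliases pin down one backward-kernel configuration (tile shape, pipeline, bias, dropout, padding flags), and the file then provides explicit specializations of fmha_bwd_dq_dk_dv_, fmha_bwd_dq_dk_dv_oneshot_ and fmha_bwd_dq_dk_dv_get_name_ for exactly that configuration, with the actual launch compiled only when targeting gfx90a or gfx942. The sketch below is only an illustration of that one-specialization-per-configuration idiom in plain standard C++; it is not the ck_tile API, and every name in it (bwd_trait, run_bwd, bwd_name) is hypothetical.

    // Illustrative only: the real files specialize fmha_bwd_dq_dk_dv_<trait>
    // and friends; here a toy trait plays the role of dq_dk_dv_trait_0.
    #include <iostream>
    #include <string>

    // One compile-time configuration, analogous to a dq_dk_dv trait alias.
    template <int HeadDim, bool Deterministic>
    struct bwd_trait {
        static constexpr int  head_dim      = HeadDim;
        static constexpr bool deterministic = Deterministic;
    };

    // Primary templates are declared once (in the real code, in a shared
    // header shared by all generated files)...
    template <typename Trait> float       run_bwd(float scale);
    template <typename Trait> std::string bwd_name();

    // ...and each generated file defines the specializations for one trait,
    // so only the configurations that are built contribute object code.
    template <>
    float run_bwd<bwd_trait<128, true>>(float scale)
    {
        // A real specialization builds kernel arguments and launches the
        // kernel, returning the measured time; this stub just returns 0.
        return scale * 0.0f;
    }

    template <>
    std::string bwd_name<bwd_trait<128, true>>()
    {
        return "bwd_hd128_deterministic";
    }

    int main()
    {
        using cfg = bwd_trait<128, true>;
        std::cout << bwd_name<cfg>() << " -> " << run_bwd<cfg>(1.0f) << "\n";
        return 0;
    }
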
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_86513d6e065a44bcb0c789eed1e7e5456e800ab6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_86513d6e065a44bcb0c789eed1e7e5456e800ab6.hip new file mode 100644 index 000000000000..7af5f97ebabe --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_86513d6e065a44bcb0c789eed1e7e5456e800ab6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_865eb90b1a2d64acc0f6fbe1d807c501fd4be3cd.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_865eb90b1a2d64acc0f6fbe1d807c501fd4be3cd.hip new file mode 100644 index 000000000000..20805e32fa9b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_865eb90b1a2d64acc0f6fbe1d807c501fd4be3cd.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8689126a7eb09d81baaf8f99dbff8932fbeab3cb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8689126a7eb09d81baaf8f99dbff8932fbeab3cb.hip new file mode 100644 index 000000000000..4a32d7b4c69a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8689126a7eb09d81baaf8f99dbff8932fbeab3cb.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_86d73393d0d8b769f30222f7817563a955c36dfc.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_86d73393d0d8b769f30222f7817563a955c36dfc.hip new file mode 100644 index 000000000000..ee406c5d6d87 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_86d73393d0d8b769f30222f7817563a955c36dfc.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_86fa51b8c7a2f3fac5cf4cd2951ed2ede5c35450.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_86fa51b8c7a2f3fac5cf4cd2951ed2ede5c35450.hip new file mode 100644 index 000000000000..61e7e816ce77 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_86fa51b8c7a2f3fac5cf4cd2951ed2ede5c35450.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; 
+ +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_875b08ca602fe48840c72cd61798acb98540fcd6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_875b08ca602fe48840c72cd61798acb98540fcd6.hip new file mode 100644 index 000000000000..b6c5e9518a1f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_875b08ca602fe48840c72cd61798acb98540fcd6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_876a418fbe6183d0392b7a7d9986d067e323e2b9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_876a418fbe6183d0392b7a7d9986d067e323e2b9.hip new file mode 100644 index 000000000000..90aee59be116 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_876a418fbe6183d0392b7a7d9986d067e323e2b9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_877e33463b3bf1853c6d2d2009af8d27bf88abbe.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_877e33463b3bf1853c6d2d2009af8d27bf88abbe.hip new file mode 100644 index 000000000000..1f369ec28396 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_877e33463b3bf1853c6d2d2009af8d27bf88abbe.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8793dc3217e154b65ebba065aa10ab4dc2374ae8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8793dc3217e154b65ebba065aa10ab4dc2374ae8.hip new file mode 100644 index 000000000000..2cc6aa370a71 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8793dc3217e154b65ebba065aa10ab4dc2374ae8.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 256, + true, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<256, ck_tile::fp16_t, true, true, true>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_87e3a06266deda093bdf28af82d8666066157fc6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_87e3a06266deda093bdf28af82d8666066157fc6.hip new file mode 100644 index 000000000000..83887203cff1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_87e3a06266deda093bdf28af82d8666066157fc6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8840e8899b4e632714632450bcef001c6070f955.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8840e8899b4e632714632450bcef001c6070f955.hip new file mode 100644 index 000000000000..28da9b996aa2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8840e8899b4e632714632450bcef001c6070f955.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_88ac7f6cbdfca2e397bcb86af4216e87166601c7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_88ac7f6cbdfca2e397bcb86af4216e87166601c7.hip new file mode 100644 index 000000000000..e5f75f10395a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_88ac7f6cbdfca2e397bcb86af4216e87166601c7.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_88c04463f9c5ce565a9daa8c22e16de80fadd707.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_88c04463f9c5ce565a9daa8c22e16de80fadd707.hip new file mode 100644 index 000000000000..2f0c3065dd22 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_88c04463f9c5ce565a9daa8c22e16de80fadd707.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_88d52c5f70abb525b9c8aa8fc1cb3997c33ed67c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_88d52c5f70abb525b9c8aa8fc1cb3997c33ed67c.hip new file mode 100644 index 000000000000..6b05a1501a71 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_88d52c5f70abb525b9c8aa8fc1cb3997c33ed67c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE 
IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, 
blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_88ea5b5346c87cc4fc1e841c518080df4ab811a2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_88ea5b5346c87cc4fc1e841c518080df4ab811a2.hip new file mode 100644 index 000000000000..42776dda70cb --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_88ea5b5346c87cc4fc1e841c518080df4ab811a2.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 128, + true, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128, + ck_tile::fp16_t, + true, + true, + false, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_88ed7f650c958a644c8031aeb88688b1e42458e5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_88ed7f650c958a644c8031aeb88688b1e42458e5.hip new file mode 100644 index 000000000000..fe7b3a8ea309 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_88ed7f650c958a644c8031aeb88688b1e42458e5.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_890aa875ac13957f00b30210477924697abf0c9e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_890aa875ac13957f00b30210477924697abf0c9e.hip new file mode 100644 index 000000000000..642ee740a655 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_890aa875ac13957f00b30210477924697abf0c9e.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_89617bdea526d12d6a33ed42b9b0018c0b173722.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_89617bdea526d12d6a33ed42b9b0018c0b173722.hip new file mode 100644 index 000000000000..ebfb3e4ce54a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_89617bdea526d12d6a33ed42b9b0018c0b173722.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_89a3327da9a3411ff1cddc67eb647083cd947a92.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_89a3327da9a3411ff1cddc67eb647083cd947a92.hip new file mode 100644 index 000000000000..6c1dfc913dc4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_89a3327da9a3411ff1cddc67eb647083cd947a92.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8a1fd28acfe85b3adac859c4bbffa4d28fe634fe.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8a1fd28acfe85b3adac859c4bbffa4d28fe634fe.hip new file mode 100644 index 000000000000..a66c9b2f45a0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8a1fd28acfe85b3adac859c4bbffa4d28fe634fe.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8a58d4bca33c4c0e79141a56688049237d170d1b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8a58d4bca33c4c0e79141a56688049237d170d1b.hip new file mode 100644 index 000000000000..422fabd80a5d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8a58d4bca33c4c0e79141a56688049237d170d1b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8a824621a50cdc3cbadc4b1f9ef18e1325385082.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8a824621a50cdc3cbadc4b1f9ef18e1325385082.hip new file mode 100644 index 000000000000..8ca426d1bea3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8a824621a50cdc3cbadc4b1f9ef18e1325385082.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8a980749c6b2a18c80426dd189e5506334343ca4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8a980749c6b2a18c80426dd189e5506334343ca4.hip new file mode 100644 index 000000000000..d9041d10891d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8a980749c6b2a18c80426dd189e5506334343ca4.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + 
return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8adbdcd28cb2f078f89adf9aad2b3d4a0a477823.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8adbdcd28cb2f078f89adf9aad2b3d4a0a477823.hip new file mode 100644 index 000000000000..96c9a760bbdd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8adbdcd28cb2f078f89adf9aad2b3d4a0a477823.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8b17c082f249649eca733a8f0cdf9a1205c3e3d7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8b17c082f249649eca733a8f0cdf9a1205c3e3d7.hip new file mode 100644 index 000000000000..74b3ade12186 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8b17c082f249649eca733a8f0cdf9a1205c3e3d7.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8b9043572cabb65435627a3faf23b18d039bbcd8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8b9043572cabb65435627a3faf23b18d039bbcd8.hip new file mode 100644 index 000000000000..a3b036096ec8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8b9043572cabb65435627a3faf23b18d039bbcd8.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8b92990df507e82f96eeb7aa3ec00c01437566fb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8b92990df507e82f96eeb7aa3ec00c01437566fb.hip new file mode 100644 index 000000000000..b2bfb22e16ed --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8b92990df507e82f96eeb7aa3ec00c01437566fb.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8bd1a40b12ce927323594fcce61eb9c20cc5e3d4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8bd1a40b12ce927323594fcce61eb9c20cc5e3d4.hip new file mode 100644 index 000000000000..a9897db92281 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8bd1a40b12ce927323594fcce61eb9c20cc5e3d4.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 64, + 256, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<256, + ck_tile::bf16_t, + false, + false, + true, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8bd7b8c63a51c8639b3cf27ad09d41ae47c480d3.hip 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8bd7b8c63a51c8639b3cf27ad09d41ae47c480d3.hip new file mode 100644 index 000000000000..cb223c4050b6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8bd7b8c63a51c8639b3cf27ad09d41ae47c480d3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template 
<> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8c074afcf33e3f3534ac3577484237fcfd2ca48e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8c074afcf33e3f3534ac3577484237fcfd2ca48e.hip new file mode 100644 index 000000000000..1f04d3a1af92 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8c074afcf33e3f3534ac3577484237fcfd2ca48e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + 
fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8c13c4f3f645a2bb475eb1c55ce1de452f0e2332.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8c13c4f3f645a2bb475eb1c55ce1de452f0e2332.hip new file mode 100644 index 000000000000..ac9d48c4fb0d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8c13c4f3f645a2bb475eb1c55ce1de452f0e2332.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
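+// dQ-conversion kernel: the BlockFmhaBwdConvertQGrad pipeline below copies the
+// AccDataType dQ accumulator into the final QGradDataType tensor, and the
+// fmha_bwd_convert_dq_ / _oneshot_ / _get_name_ specializations at the end of
+// the file make this configuration available to the dispatch code. The launch
+// is guarded to gfx90a / gfx942, with a stub returning 0.0 elsewhere.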
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 64, + 256, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<256, + ck_tile::fp16_t, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8c3bd4e029bba76ebfc79e6522dbc8ca0bba5dd2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8c3bd4e029bba76ebfc79e6522dbc8ca0bba5dd2.hip new file mode 100644 index 000000000000..7cd1eab74909 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8c3bd4e029bba76ebfc79e6522dbc8ca0bba5dd2.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
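+// One bf16 backward (dQ/dK/dV) tile configuration of the CK tile FMHA kernel.
+// The explicit specializations at the end of the file expose it to the
+// dispatcher: fmha_bwd_dq_dk_dv_ wraps ck_tile::launch_kernel and prints the
+// kernel name when logging is enabled, the _oneshot_ variant performs a single
+// launch on the given stream, and _get_name_ reports the kernel name. The
+// kernel body is compiled only for gfx90a / gfx942.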
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8c4688cbd23727dd0ea9a36fb977b31aeae98d65.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8c4688cbd23727dd0ea9a36fb977b31aeae98d65.hip new file mode 100644 index 000000000000..b3292ab6d77e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8c4688cbd23727dd0ea9a36fb977b31aeae98d65.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8c7970957024de050748d3e31cef434f582d968b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8c7970957024de050748d3e31cef434f582d968b.hip new file mode 100644 index 000000000000..45a342f057d4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8c7970957024de050748d3e31cef434f582d968b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8cdcdeb845e7bcdb89ef70ab2a97157d4db3cb52.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8cdcdeb845e7bcdb89ef70ab2a97157d4db3cb52.hip new file mode 100644 index 000000000000..a14bb66e2b54 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8cdcdeb845e7bcdb89ef70ab2a97157d4db3cb52.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8cf1007430da272174d3476d042f398627e83512.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8cf1007430da272174d3476d042f398627e83512.hip new file mode 100644 index 000000000000..ad4263d4eb17 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8cf1007430da272174d3476d042f398627e83512.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8d079c1eb36db8461fa8b861c56760afcd97cc34.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8d079c1eb36db8461fa8b861c56760afcd97cc34.hip new file mode 100644 index 000000000000..adef882533f9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8d079c1eb36db8461fa8b861c56760afcd97cc34.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + 
return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8d7549e66ef309e32779ddc2a1f14e79bae53754.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8d7549e66ef309e32779ddc2a1f14e79bae53754.hip new file mode 100644 index 000000000000..d2e731ca5550 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8d7549e66ef309e32779ddc2a1f14e79bae53754.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8d79fe8a600c3b4e0ec9aa510f8036ba2b608985.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8d79fe8a600c3b4e0ec9aa510f8036ba2b608985.hip new file mode 100644 index 000000000000..9b994bab809e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8d79fe8a600c3b4e0ec9aa510f8036ba2b608985.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8da8285bd6182355e3164cdc5a983375cdf0a61d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8da8285bd6182355e3164cdc5a983375cdf0a61d.hip new file mode 100644 index 000000000000..61f063c38398 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8da8285bd6182355e3164cdc5a983375cdf0a61d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e1b48a28b71c7f4c78eb14321b39951a7c5e903.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e1b48a28b71c7f4c78eb14321b39951a7c5e903.hip new file mode 100644 index 000000000000..de24e5557f53 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e1b48a28b71c7f4c78eb14321b39951a7c5e903.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e2c587db8bd9f1b551624e0cf8b67a90245d7da.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e2c587db8bd9f1b551624e0cf8b67a90245d7da.hip new file mode 100644 index 000000000000..e94aa1fca6ab --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e2c587db8bd9f1b551624e0cf8b67a90245d7da.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 64, + 256, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<256, + ck_tile::fp16_t, + false, + true, + false, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e2d5f979fc4fbd0991581a020a414f9c8656ae2.hip 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e2d5f979fc4fbd0991581a020a414f9c8656ae2.hip new file mode 100644 index 000000000000..682bef8a5075 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e2d5f979fc4fbd0991581a020a414f9c8656ae2.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> 
+void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e431313fe082958d31b68d2fd0d61df0fe56736.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e431313fe082958d31b68d2fd0d61df0fe56736.hip new file mode 100644 index 000000000000..25b797b2348e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e431313fe082958d31b68d2fd0d61df0fe56736.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e50ea8dd480012cbe10be392cd26d1870e6ef9b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e50ea8dd480012cbe10be392cd26d1870e6ef9b.hip new file mode 100644 index 000000000000..ed783900cfd6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e50ea8dd480012cbe10be392cd26d1870e6ef9b.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
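+// Forward-pass counterpart: one bf16 configuration of the CK tile FMHA forward
+// kernel built on the QRKSVS_ASYNC pipeline, exposed through the fmha_fwd_
+// specialization below with the same gfx90a / gfx942 guard around the launch.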
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e675919a6c7758cbbeecb83b7ac6c62f95cdb46.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e675919a6c7758cbbeecb83b7ac6c62f95cdb46.hip new file mode 100644 index 000000000000..2dc0c8d0afe5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e675919a6c7758cbbeecb83b7ac6c62f95cdb46.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 128, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, true, false>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e812705ae3e452810794fa7caceef2ef6066dfb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e812705ae3e452810794fa7caceef2ef6066dfb.hip new file mode 100644 index 000000000000..74ef40ca5306 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e812705ae3e452810794fa7caceef2ef6066dfb.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e816fcad5e9ecfca94a6491eb2274bcc41e558b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e816fcad5e9ecfca94a6491eb2274bcc41e558b.hip new file mode 100644 index 000000000000..ef246224837b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e816fcad5e9ecfca94a6491eb2274bcc41e558b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e938d0e3ad30db201880642e57758285b2ec4cb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e938d0e3ad30db201880642e57758285b2ec4cb.hip new file mode 100644 index 000000000000..16488c9a130e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8e938d0e3ad30db201880642e57758285b2ec4cb.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 64, + true, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, true, true, true>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8efb5fc2ace6839eac741c5e6616665845f43566.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8efb5fc2ace6839eac741c5e6616665845f43566.hip new file mode 
100644 index 000000000000..af93b92d570f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8efb5fc2ace6839eac741c5e6616665845f43566.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; 
+ auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8f607ee20c0d92b6dbd0338f139517fdcce98d0c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8f607ee20c0d92b6dbd0338f139517fdcce98d0c.hip new file mode 100644 index 000000000000..03cbd9201913 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8f607ee20c0d92b6dbd0338f139517fdcce98d0c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + 
false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8f6e463eedd3e65b9c79feed3cd92ad8cbc9f036.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8f6e463eedd3e65b9c79feed3cd92ad8cbc9f036.hip new file mode 100644 index 000000000000..b10c84772f57 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8f6e463eedd3e65b9c79feed3cd92ad8cbc9f036.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
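// Illustrative sketch (not from the patch): the launchers above compile their
// kernel launch only for gfx90a / gfx942; on any other target the #else branch
// turns the function into a no-op that reports 0.0. guarded_launch() and
// do_launch() below are hypothetical stand-ins for that guard, shown in
// isolation so the control flow is easy to follow.
#include <cstdio>

static float do_launch() { return 0.0f; }   // placeholder for the real launch

float guarded_launch()
{
#if (defined(__gfx90a__) || defined(__gfx942__))
    return do_launch();   // supported AMD arch: build kargs/grids and launch
#else
    return 0.0f;          // unsupported arch: nothing is launched
#endif
}

int main() { std::printf("%f\n", guarded_launch()); return 0; }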
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8f7166d4bb0c1c9b9999ba16a1adbf09ebfdb6f1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8f7166d4bb0c1c9b9999ba16a1adbf09ebfdb6f1.hip new file mode 100644 index 000000000000..ee2284694134 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8f7166d4bb0c1c9b9999ba16a1adbf09ebfdb6f1.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8fa4c40e244b412a07933d369704bcdaa6d5e74c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8fa4c40e244b412a07933d369704bcdaa6d5e74c.hip new file mode 100644 index 000000000000..7684341efb14 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8fa4c40e244b412a07933d369704bcdaa6d5e74c.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8fb224b40a7be7db0a9c5c08cc5ab05b526c14e8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8fb224b40a7be7db0a9c5c08cc5ab05b526c14e8.hip new file mode 100644 index 000000000000..095b4f99e34a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8fb224b40a7be7db0a9c5c08cc5ab05b526c14e8.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
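// Illustrative sketch (not from the patch): each backward file above exposes two
// entry points for the same kernel -- a variant that returns whatever the launch
// helper reports (typically an elapsed time when timing is enabled), and a
// "oneshot" variant that simply enqueues the kernel on the caller's stream and
// returns nothing. timed_entry(), oneshot_entry() and run_kernel() are
// hypothetical stand-ins for that split, with host-side timing used here only
// for illustration.
#include <chrono>

struct stream_t { int stream_id_ = 0; };

static void run_kernel(const stream_t&) { /* enqueue work on the stream */ }

// Timed variant: measure and report how long the launch took.
float timed_entry(const stream_t& s)
{
    const auto t0 = std::chrono::steady_clock::now();
    run_kernel(s);
    const auto t1 = std::chrono::steady_clock::now();
    return std::chrono::duration<float, std::milli>(t1 - t0).count();
}

// Oneshot variant: fire and forget, no timing.
void oneshot_entry(const stream_t& s) { run_kernel(s); }

int main()
{
    stream_t s{0};
    const float ms = timed_entry(s);
    oneshot_entry(s);
    return ms >= 0.0f ? 0 : 1;
}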
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8fb33fc20f2e85e915f1b1529ae87981dfcaf86d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8fb33fc20f2e85e915f1b1529ae87981dfcaf86d.hip new file mode 100644 index 000000000000..f866e39e0f09 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8fb33fc20f2e85e915f1b1529ae87981dfcaf86d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8fc08b4f3959a2375ac03f40c4ce12d70cdc2d80.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8fc08b4f3959a2375ac03f40c4ce12d70cdc2d80.hip new file mode 100644 index 000000000000..1f7c65978ef0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_8fc08b4f3959a2375ac03f40c4ce12d70cdc2d80.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9009b7d39346537aa6c4a4e46b81139f603edb60.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9009b7d39346537aa6c4a4e46b81139f603edb60.hip new file mode 100644 index 000000000000..27a174e9f93e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9009b7d39346537aa6c4a4e46b81139f603edb60.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
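// Illustrative sketch (not from the patch): each generated file pins down one
// combination of head dim, dtype, masking, bias and dropout in its trait type,
// which is why there are so many near-identical sources. A dispatcher can then
// map runtime parameters to the one specialization compiled for that
// combination. The enums, Key and registry() below are hypothetical stand-ins
// for that selection step, not the actual lookup used by this patch.
#include <functional>
#include <map>
#include <stdexcept>
#include <tuple>

enum class DType { fp16, bf16 };
enum class Bias  { none, alibi };

using Key      = std::tuple<int /*head_dim*/, DType, bool /*masked*/, Bias>;
using Launcher = std::function<float()>;

static std::map<Key, Launcher>& registry()
{
    static std::map<Key, Launcher> r;
    return r;
}

static float dispatch(int head_dim, DType dt, bool masked, Bias bias)
{
    const auto it = registry().find(Key{head_dim, dt, masked, bias});
    if (it == registry().end())
        throw std::runtime_error("no kernel instantiated for this configuration");
    return it->second();
}

int main()
{
    registry()[Key{256, DType::bf16, false, Bias::alibi}] = [] { return 0.0f; };
    return dispatch(256, DType::bf16, false, Bias::alibi) == 0.0f ? 0 : 1;
}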
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_900d7f81c73b35ea64095d01c5d48d9190839e0a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_900d7f81c73b35ea64095d01c5d48d9190839e0a.hip new file mode 100644 index 000000000000..62cad16726a5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_900d7f81c73b35ea64095d01c5d48d9190839e0a.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9068ba8df8b0e977e9769f6acf6cfee6b00b9922.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9068ba8df8b0e977e9769f6acf6cfee6b00b9922.hip new file mode 100644 index 000000000000..3c572aee9c44 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9068ba8df8b0e977e9769f6acf6cfee6b00b9922.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_906fa8bf5e992ddc25815486ae9c24d8bfba7227.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_906fa8bf5e992ddc25815486ae9c24d8bfba7227.hip new file mode 100644 index 000000000000..22c7abd5ac6c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_906fa8bf5e992ddc25815486ae9c24d8bfba7227.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) 
+ return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_90b17d8cba28cceddb3ef907df878aeef0762d15.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_90b17d8cba28cceddb3ef907df878aeef0762d15.hip new file mode 100644 index 000000000000..55938f955b03 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_90b17d8cba28cceddb3ef907df878aeef0762d15.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_90da0d469cca5c8481504148468460c85a15c559.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_90da0d469cca5c8481504148468460c85a15c559.hip new file mode 100644 index 000000000000..cbf40ea7340c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_90da0d469cca5c8481504148468460c85a15c559.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_90e5c56e92712d00092ba102a5eb5176a3e5d471.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_90e5c56e92712d00092ba102a5eb5176a3e5d471.hip new file mode 100644 index 000000000000..a51b705cf201 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_90e5c56e92712d00092ba102a5eb5176a3e5d471.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_910cb8bd09d287a1566265eb1e8894fe68d3cc81.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_910cb8bd09d287a1566265eb1e8894fe68d3cc81.hip new file mode 100644 index 000000000000..1d0fb66b1eee --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_910cb8bd09d287a1566265eb1e8894fe68d3cc81.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; 
+ +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_915b75db795dbef037b14b003ee073665fe35d3e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_915b75db795dbef037b14b003ee073665fe35d3e.hip new file mode 100644 index 000000000000..a010fddc4904 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_915b75db795dbef037b14b003ee073665fe35d3e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9163ae070075f26926a86d39e15c27e6edb1f1cf.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9163ae070075f26926a86d39e15c27e6edb1f1cf.hip new file mode 100644 index 000000000000..6c74184d0c3e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9163ae070075f26926a86d39e15c27e6edb1f1cf.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_91695dea4171747fb3cc6d910459f800608d07c1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_91695dea4171747fb3cc6d910459f800608d07c1.hip new file mode 100644 index 000000000000..eda2b272b9e8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_91695dea4171747fb3cc6d910459f800608d07c1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_919ae177b7a793fa352c4f6bb8e4175f3064d814.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_919ae177b7a793fa352c4f6bb8e4175f3064d814.hip new file mode 100644 index 000000000000..475bca57c36c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_919ae177b7a793fa352c4f6bb8e4175f3064d814.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_91a6200e36944b1f11106c02f7fcee053f01ee71.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_91a6200e36944b1f11106c02f7fcee053f01ee71.hip new file mode 100644 index 000000000000..f116c775fa88 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_91a6200e36944b1f11106c02f7fcee053f01ee71.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_91b9e2616c2fe0480096b1ccf0f74d584b220146.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_91b9e2616c2fe0480096b1ccf0f74d584b220146.hip new file mode 100644 index 000000000000..972b36fe16a5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_91b9e2616c2fe0480096b1ccf0f74d584b220146.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 128, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128, + ck_tile::fp16_t, + false, + true, + false, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_91c916e14198f6d18dc89915e379b01070434e91.hip 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_91c916e14198f6d18dc89915e379b01070434e91.hip new file mode 100644 index 000000000000..eef8d609f643 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_91c916e14198f6d18dc89915e379b01070434e91.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> 
+void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9207a63fc55c411c73e4f93306c5ffed800dd249.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9207a63fc55c411c73e4f93306c5ffed800dd249.hip new file mode 100644 index 000000000000..c5675d69691b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9207a63fc55c411c73e4f93306c5ffed800dd249.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92121fd448b4640a17e1a7fe73bb7b58714c0afb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92121fd448b4640a17e1a7fe73bb7b58714c0afb.hip new file mode 100644 index 000000000000..7a41e134c84b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92121fd448b4640a17e1a7fe73bb7b58714c0afb.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + 
s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_921f789d619db6f225e8e9d646e93bbc9dc1a669.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_921f789d619db6f225e8e9d646e93bbc9dc1a669.hip new file mode 100644 index 000000000000..ba1011e9d6c6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_921f789d619db6f225e8e9d646e93bbc9dc1a669.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = 
fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92739f4464512feee083b875e11e11eee4f5b448.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92739f4464512feee083b875e11e11eee4f5b448.hip new file mode 100644 index 000000000000..5f9078186d23 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92739f4464512feee083b875e11e11eee4f5b448.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92992be6252f2afdc368bd4baec4b8a55ae0abf8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92992be6252f2afdc368bd4baec4b8a55ae0abf8.hip new file mode 100644 index 000000000000..7ab5d7d17057 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92992be6252f2afdc368bd4baec4b8a55ae0abf8.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92b0770fe64e3c60b9e56170aa88bbf74802a813.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92b0770fe64e3c60b9e56170aa88bbf74802a813.hip new file mode 100644 index 000000000000..ced9c424f09b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92b0770fe64e3c60b9e56170aa88bbf74802a813.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92b722cdabcfaa388ccc6ccceb7e42462f3bdcd1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92b722cdabcfaa388ccc6ccceb7e42462f3bdcd1.hip new file mode 100644 index 000000000000..7b63fe2ccfb4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92b722cdabcfaa388ccc6ccceb7e42462f3bdcd1.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92ba64cdf615c1be2865f027a293cb530fc07dc6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92ba64cdf615c1be2865f027a293cb530fc07dc6.hip new file mode 100644 index 000000000000..4440b7e6504b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92ba64cdf615c1be2865f027a293cb530fc07dc6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + 
return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92d841e6d783bb46d841aafd9027f92dd1b61b88.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92d841e6d783bb46d841aafd9027f92dd1b61b88.hip new file mode 100644 index 000000000000..1e13fa7afc06 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92d841e6d783bb46d841aafd9027f92dd1b61b88.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92e53359c69bbe4d7405d45261a8a62008eb7d06.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92e53359c69bbe4d7405d45261a8a62008eb7d06.hip new file mode 100644 index 000000000000..d2ae7ae45fbe --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92e53359c69bbe4d7405d45261a8a62008eb7d06.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92f9ad0fb65638cfffb3e7786f2cbf01d9585b23.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92f9ad0fb65638cfffb3e7786f2cbf01d9585b23.hip new file mode 100644 index 000000000000..908e7ff0b5bd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_92f9ad0fb65638cfffb3e7786f2cbf01d9585b23.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_93054acb8a9508fd0f0f486367fb62454de47c39.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_93054acb8a9508fd0f0f486367fb62454de47c39.hip new file mode 100644 index 000000000000..4501cb318fa3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_93054acb8a9508fd0f0f486367fb62454de47c39.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_931cf8d05cfa45319f4e5bb49334d35a530bffcf.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_931cf8d05cfa45319f4e5bb49334d35a530bffcf.hip new file mode 100644 index 000000000000..6da3c2e6000c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_931cf8d05cfa45319f4e5bb49334d35a530bffcf.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + 
return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_93728d999ae43ee1b5a16e60b90cf8533c7d303f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_93728d999ae43ee1b5a16e60b90cf8533c7d303f.hip new file mode 100644 index 000000000000..0e4e84ffe0b8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_93728d999ae43ee1b5a16e60b90cf8533c7d303f.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_937801fbb43fb6797f0425f08d13926b74d87c4a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_937801fbb43fb6797f0425f08d13926b74d87c4a.hip new file mode 100644 index 000000000000..1c77f540b689 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_937801fbb43fb6797f0425f08d13926b74d87c4a.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 32, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, false>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_937c48d0b7096ad6c8bc445f13f2c8c1934695ab.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_937c48d0b7096ad6c8bc445f13f2c8c1934695ab.hip new file mode 100644 index 000000000000..814ad8ec592f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_937c48d0b7096ad6c8bc445f13f2c8c1934695ab.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_93b885d6869400b0dc2ef1b2c2636ddfd21cde31.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_93b885d6869400b0dc2ef1b2c2636ddfd21cde31.hip new file mode 100644 index 000000000000..923593adb171 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_93b885d6869400b0dc2ef1b2c2636ddfd21cde31.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 32, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, false, true>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_942439e4f5644a3a4630481bc7d98834b29b6e1c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_942439e4f5644a3a4630481bc7d98834b29b6e1c.hip new file mode 100644 index 000000000000..dd4fd8540922 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_942439e4f5644a3a4630481bc7d98834b29b6e1c.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_94a94d145e575747c8956ac703810582c819e2e8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_94a94d145e575747c8956ac703810582c819e2e8.hip new file mode 100644 index 000000000000..734603126baf --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_94a94d145e575747c8956ac703810582c819e2e8.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_94aa519eb57e5797125728492d9330f5c0f0670a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_94aa519eb57e5797125728492d9330f5c0f0670a.hip new file mode 100644 index 000000000000..873f1b44dede --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_94aa519eb57e5797125728492d9330f5c0f0670a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_94f6f9dee9f0c3825d91f4d320a5280070e60ee7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_94f6f9dee9f0c3825d91f4d320a5280070e60ee7.hip new file mode 100644 index 000000000000..026413f14c76 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_94f6f9dee9f0c3825d91f4d320a5280070e60ee7.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_95061acc6650fc7b79fa1fe5b2b1e083555eec2c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_95061acc6650fc7b79fa1fe5b2b1e083555eec2c.hip new file mode 100644 index 000000000000..556c9bd087b8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_95061acc6650fc7b79fa1fe5b2b1e083555eec2c.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_951343832a5bfd060c8d12da0d8a090f070a717d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_951343832a5bfd060c8d12da0d8a090f070a717d.hip new file mode 100644 index 000000000000..759e5ee3190d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_951343832a5bfd060c8d12da0d8a090f070a717d.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9545f95c1093c60f0fb6c794636f79aaeb53b733.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9545f95c1093c60f0fb6c794636f79aaeb53b733.hip new file mode 100644 index 000000000000..9038cd3adf4a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9545f95c1093c60f0fb6c794636f79aaeb53b733.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 32, + true, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<32, + ck_tile::bf16_t, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_95530399ad7b43d8ce2c89da24c71056f2146b18.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_95530399ad7b43d8ce2c89da24c71056f2146b18.hip new file mode 100644 index 000000000000..dbc2e060f86e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_95530399ad7b43d8ce2c89da24c71056f2146b18.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9583148fd684a7e6a312127e023798278415bd27.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9583148fd684a7e6a312127e023798278415bd27.hip new file mode 100644 index 000000000000..dcc79ee6a0b9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9583148fd684a7e6a312127e023798278415bd27.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9594816877815bc0294610ca24f986fdccdc7c6f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9594816877815bc0294610ca24f986fdccdc7c6f.hip new file mode 100644 index 000000000000..ca040732924d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9594816877815bc0294610ca24f986fdccdc7c6f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_960ecb3013071fb65f2d5ed4c947c4bf303e5308.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_960ecb3013071fb65f2d5ed4c947c4bf303e5308.hip new file mode 100644 index 000000000000..786928b952cd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_960ecb3013071fb65f2d5ed4c947c4bf303e5308.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
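For readers skimming the generated sources above: every one of these .hip files follows the same shape. It fixes one tile/pipeline configuration through type aliases, then provides explicit specializations of the shared launcher templates (fmha_bwd_dq_dk_dv_, fmha_bwd_dq_dk_dv_oneshot_, fmha_bwd_dq_dk_dv_get_name_) for that configuration's traits type, with the actual kernel launch guarded by the gfx90a/gfx942 architecture check. The sketch below is only an illustration of that one-specialization-per-translation-unit dispatch pattern; it is a standalone approximation, and every identifier in it (example_traits, example_stream_config, example_args, fmha_bwd_dq_dk_dv_example) is a placeholder rather than a ck_tile or PyTorch symbol.

// Standalone sketch of the per-file specialization pattern (assumed names;
// not part of the generated sources in this PR).
#include <iostream>
#include <string>

struct example_stream_config { int log_level_ = 0; }; // stands in for ck_tile::stream_config
struct example_args {};                               // stands in for fmha_bwd_args
struct example_traits {};                             // stands in for one dq_dk_dv_trait_* alias

// Shared launcher templates, declared once in a common header.
template <typename Traits>
float fmha_bwd_dq_dk_dv_example(const example_stream_config& s, example_args a);

template <typename Traits>
std::string fmha_bwd_dq_dk_dv_example_name();

// Each generated translation unit contributes exactly one explicit
// specialization pair for its configuration.
template <>
float fmha_bwd_dq_dk_dv_example<example_traits>(const example_stream_config& s,
                                                example_args /*a*/)
{
    if (s.log_level_ > 0)
        std::cout << ", example_kernel" << std::flush;
    // A real generated file builds kernel arguments and launches the ck_tile
    // kernel here, inside the gfx90a/gfx942 preprocessor guard; the sketch
    // just returns a dummy elapsed time.
    return 0.0f;
}

template <>
std::string fmha_bwd_dq_dk_dv_example_name<example_traits>()
{
    return "example_kernel";
}

int main()
{
    example_stream_config cfg;
    cfg.log_level_ = 1;
    float t = fmha_bwd_dq_dk_dv_example<example_traits>(cfg, example_args{});
    std::cout << '\n' << fmha_bwd_dq_dk_dv_example_name<example_traits>()
              << " took " << t << " ms" << std::endl;
}

Splitting the instantiations one per file in this way is a common codegen choice: each translation unit stays small enough to build in parallel, and a dispatcher only needs the traits type to resolve the matching specialization at link time.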
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9638c9618dbf2af119e37596f7eb0fd3f8d72748.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9638c9618dbf2af119e37596f7eb0fd3f8d72748.hip new file mode 100644 index 000000000000..4f7978f8ab6f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9638c9618dbf2af119e37596f7eb0fd3f8d72748.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_963986150adcd6e1d3886bacf2166de1252e14df.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_963986150adcd6e1d3886bacf2166de1252e14df.hip new file mode 100644 index 000000000000..cc7bb4ca2c1e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_963986150adcd6e1d3886bacf2166de1252e14df.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 64, + 256, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<256, + ck_tile::bf16_t, + false, + false, + true, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_964f916d3484295b5918e2e4c22c5529588a5662.hip 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_964f916d3484295b5918e2e4c22c5529588a5662.hip new file mode 100644 index 000000000000..2021cb3b9482 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_964f916d3484295b5918e2e4c22c5529588a5662.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + 
+template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9689ecd7bf51bcffe9f5002959bdda41c50a3c8b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9689ecd7bf51bcffe9f5002959bdda41c50a3c8b.hip new file mode 100644 index 000000000000..ef65aa0cbe92 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9689ecd7bf51bcffe9f5002959bdda41c50a3c8b.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 64, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<64, + ck_tile::fp16_t, + false, + false, + true, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_968fc75a7d102aca068e3ceb6111728c280fa837.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_968fc75a7d102aca068e3ceb6111728c280fa837.hip new file mode 100644 index 000000000000..50a0355bf06d --- /dev/null +++ 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_968fc75a7d102aca068e3ceb6111728c280fa837.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_96c129dd4c798343d6f78ab78056f0faf2f1c9d3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_96c129dd4c798343d6f78ab78056f0faf2f1c9d3.hip new file mode 100644 index 000000000000..8594f3173431 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_96c129dd4c798343d6f78ab78056f0faf2f1c9d3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + 
+template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_96c5e79f54b71677124f555b0ae4bfd27248d099.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_96c5e79f54b71677124f555b0ae4bfd27248d099.hip new file mode 100644 index 000000000000..21bb04489710 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_96c5e79f54b71677124f555b0ae4bfd27248d099.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_96caa2056d99eb67ada498e287b4fae984397691.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_96caa2056d99eb67ada498e287b4fae984397691.hip new file mode 100644 index 000000000000..cb1ec3c5f7e6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_96caa2056d99eb67ada498e287b4fae984397691.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_96dee49ec6755006d67f0c30c65f50558bba69b0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_96dee49ec6755006d67f0c30c65f50558bba69b0.hip new file mode 100644 index 000000000000..1b1bf13c2909 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_96dee49ec6755006d67f0c30c65f50558bba69b0.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_96f1bb85dff8c97846f6b2e8796a6289bcd0d9d3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_96f1bb85dff8c97846f6b2e8796a6289bcd0d9d3.hip new file mode 100644 index 000000000000..d32f58de2c3c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_96f1bb85dff8c97846f6b2e8796a6289bcd0d9d3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_970073c70133ff2ee4737f803a0ac43801c47242.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_970073c70133ff2ee4737f803a0ac43801c47242.hip new file mode 100644 index 000000000000..ce373394af37 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_970073c70133ff2ee4737f803a0ac43801c47242.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_971a08c2e48d805b295d979b24173a04cf58def0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_971a08c2e48d805b295d979b24173a04cf58def0.hip new file mode 100644 index 000000000000..1bb18015aa15 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_971a08c2e48d805b295d979b24173a04cf58def0.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, 
grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_97246460c21bc66c0f13936d27477a9fca1c44d1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_97246460c21bc66c0f13936d27477a9fca1c44d1.hip new file mode 100644 index 000000000000..ab8f2aeef41d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_97246460c21bc66c0f13936d27477a9fca1c44d1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if 
(defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9745b04a8026a01828c5dd606d89d044d3ed1d99.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9745b04a8026a01828c5dd606d89d044d3ed1d99.hip new file mode 100644 index 000000000000..2c75f779939c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9745b04a8026a01828c5dd606d89d044d3ed1d99.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 32, + true, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<32, + ck_tile::fp16_t, + true, + true, + false, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_976cf509d9c2bf86ba6ee5ded544fa8e6717f590.hip 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_976cf509d9c2bf86ba6ee5ded544fa8e6717f590.hip new file mode 100644 index 000000000000..063c77a45493 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_976cf509d9c2bf86ba6ee5ded544fa8e6717f590.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_977137b371df841993c8d0584be7d83aca6add78.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_977137b371df841993c8d0584be7d83aca6add78.hip new file mode 100644 index 000000000000..152d98ac85d7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_977137b371df841993c8d0584be7d83aca6add78.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_97851d5ecbf02f8af623988b1a39c0b91e51533a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_97851d5ecbf02f8af623988b1a39c0b91e51533a.hip new file mode 100644 index 000000000000..63e59aa9da0b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_97851d5ecbf02f8af623988b1a39c0b91e51533a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9801b25e0f132d647934deb395b62a3f70cc7c88.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9801b25e0f132d647934deb395b62a3f70cc7c88.hip new file mode 100644 index 000000000000..8c6225629062 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9801b25e0f132d647934deb395b62a3f70cc7c88.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_987a617fae00fa90a1ba60937b0312c81087c19e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_987a617fae00fa90a1ba60937b0312c81087c19e.hip new file mode 100644 index 000000000000..78a4ab4e6e47 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_987a617fae00fa90a1ba60937b0312c81087c19e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_987f00dd759d9714693e7517dfaa8bb427294d42.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_987f00dd759d9714693e7517dfaa8bb427294d42.hip new file mode 100644 index 000000000000..4fadf647940e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_987f00dd759d9714693e7517dfaa8bb427294d42.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9893336a4b00b2a63f23ed7e13ec54c82d9e5063.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9893336a4b00b2a63f23ed7e13ec54c82d9e5063.hip new file mode 100644 index 000000000000..43bcd2199617 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9893336a4b00b2a63f23ed7e13ec54c82d9e5063.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 64, + true, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<64, + ck_tile::bf16_t, + true, + true, + false, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_98e484adeddf3394d8d7693b808d83b64c71ee69.hip 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_98e484adeddf3394d8d7693b808d83b64c71ee69.hip new file mode 100644 index 000000000000..704c87b0a9b1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_98e484adeddf3394d8d7693b808d83b64c71ee69.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + 
+template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_98f5efcd500ce6b9ffc14bc9877e0ba457539925.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_98f5efcd500ce6b9ffc14bc9877e0ba457539925.hip new file mode 100644 index 000000000000..3ecabacb7f24 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_98f5efcd500ce6b9ffc14bc9877e0ba457539925.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_98f9a4f4d85f292b78123599a2e1798f12aa545b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_98f9a4f4d85f292b78123599a2e1798f12aa545b.hip new file mode 100644 index 000000000000..65076c99dd4f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_98f9a4f4d85f292b78123599a2e1798f12aa545b.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9990e6ad243a48b84304b5cad0c663c0802aedfd.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9990e6ad243a48b84304b5cad0c663c0802aedfd.hip new file mode 100644 index 000000000000..bb635d6ff822 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9990e6ad243a48b84304b5cad0c663c0802aedfd.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; 
+ return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_99ae680eed89ea93a3a94586bd5a68dbc5439f37.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_99ae680eed89ea93a3a94586bd5a68dbc5439f37.hip new file mode 100644 index 000000000000..f2c7385f46d2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_99ae680eed89ea93a3a94586bd5a68dbc5439f37.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_99e2f290b962f1617b0a9d4fd6d55c43e4439d6f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_99e2f290b962f1617b0a9d4fd6d55c43e4439d6f.hip new file mode 100644 index 000000000000..1d331ff19726 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_99e2f290b962f1617b0a9d4fd6d55c43e4439d6f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_99f8352674bd6bbe98944a1c0a769a4fc028a623.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_99f8352674bd6bbe98944a1c0a769a4fc028a623.hip new file mode 100644 index 000000000000..f101ae9fe349 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_99f8352674bd6bbe98944a1c0a769a4fc028a623.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9a0a70932bd587759df1e5e150b25b0126d7b529.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9a0a70932bd587759df1e5e150b25b0126d7b529.hip new file mode 100644 index 000000000000..a8b9243487d5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9a0a70932bd587759df1e5e150b25b0126d7b529.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = 
k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9a20fa19d8d30654602e363806f559113218d66d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9a20fa19d8d30654602e363806f559113218d66d.hip new file mode 100644 index 000000000000..53302ce3d974 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9a20fa19d8d30654602e363806f559113218d66d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << 
std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9a8e04fe9432a60f86ff0369e8c1851821074a04.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9a8e04fe9432a60f86ff0369e8c1851821074a04.hip new file mode 100644 index 000000000000..b3d0ad587cfe --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9a8e04fe9432a60f86ff0369e8c1851821074a04.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + 
+using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9a9edbe35a8fac7796f00bde836bd547044770ea.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9a9edbe35a8fac7796f00bde836bd547044770ea.hip new file mode 100644 index 000000000000..79ab1cd43cf3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9a9edbe35a8fac7796f00bde836bd547044770ea.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9ab73ea77ec20ea3bfaf995dacf93a6960ecdca0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9ab73ea77ec20ea3bfaf995dacf93a6960ecdca0.hip new file mode 100644 index 000000000000..ca4f0e8d462e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9ab73ea77ec20ea3bfaf995dacf93a6960ecdca0.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, 
ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9ad1f99284aafc8d7908d062f179a056eb314925.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9ad1f99284aafc8d7908d062f179a056eb314925.hip new file mode 100644 index 000000000000..57ddba224878 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9ad1f99284aafc8d7908d062f179a056eb314925.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = 
fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9ae866c7db36286876818bfb718ac35204fa3843.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9ae866c7db36286876818bfb718ac35204fa3843.hip new file mode 100644 index 000000000000..e177ef281b7f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9ae866c7db36286876818bfb718ac35204fa3843.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9afe4b6f3b901ff4af81bd4f1cd8ff19f09d0b07.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9afe4b6f3b901ff4af81bd4f1cd8ff19f09d0b07.hip new file mode 100644 index 000000000000..e23766791033 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9afe4b6f3b901ff4af81bd4f1cd8ff19f09d0b07.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9b062dd633645772e4f2caffd111af73184f7657.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9b062dd633645772e4f2caffd111af73184f7657.hip new file mode 100644 index 000000000000..8fe8fbf9b2b6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9b062dd633645772e4f2caffd111af73184f7657.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9b327f0fa1155f2235d76be45cd22e3db5a69429.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9b327f0fa1155f2235d76be45cd22e3db5a69429.hip new file mode 100644 index 000000000000..63a0c899b89f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9b327f0fa1155f2235d76be45cd22e3db5a69429.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9b4dcde1ae3446b825dea739d4295c1d1ec5c4be.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9b4dcde1ae3446b825dea739d4295c1d1ec5c4be.hip new file mode 100644 index 000000000000..263054dca17f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9b4dcde1ae3446b825dea739d4295c1d1ec5c4be.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9b6d08e63b9a90f2524cbfa8c5fcf8b82a1d2d36.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9b6d08e63b9a90f2524cbfa8c5fcf8b82a1d2d36.hip new file mode 100644 index 000000000000..3c95ad357d9f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9b6d08e63b9a90f2524cbfa8c5fcf8b82a1d2d36.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); 
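+    // On targets other than gfx90a/gfx942 this unit compiles to the fallback
+    // below, which skips the launch and returns 0.0; when the kernel does run,
+    // the float returned by ck_tile::launch_kernel is presumably an
+    // elapsed-time measurement taken under the supplied stream_config.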
+#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9b73c92a13757877f34bd8a13c6fb29b60999020.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9b73c92a13757877f34bd8a13c6fb29b60999020.hip new file mode 100644 index 000000000000..c12ba4da6d44 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9b73c92a13757877f34bd8a13c6fb29b60999020.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) 
|| defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9b841b7cf5da31f0c30ec42c91cc8d5bd3fedd03.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9b841b7cf5da31f0c30ec42c91cc8d5bd3fedd03.hip new file mode 100644 index 000000000000..82d78dccfa0b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9b841b7cf5da31f0c30ec42c91cc8d5bd3fedd03.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9bcc791049e3ff9ebc1a9085d2d20efcc2f99b71.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9bcc791049e3ff9ebc1a9085d2d20efcc2f99b71.hip new file mode 100644 index 000000000000..a3cda6896469 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9bcc791049e3ff9ebc1a9085d2d20efcc2f99b71.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9bf235679af1ca03a6e601b4cf6cd0416d1c9091.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9bf235679af1ca03a6e601b4cf6cd0416d1c9091.hip new file mode 100644 index 000000000000..b2b9004457bf --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9bf235679af1ca03a6e601b4cf6cd0416d1c9091.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9c4fc7cda4b560040cec93f63021b529aa1ee3fd.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9c4fc7cda4b560040cec93f63021b529aa1ee3fd.hip new file mode 100644 index 000000000000..401c9980bc09 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9c4fc7cda4b560040cec93f63021b529aa1ee3fd.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9ca3b1d36d777213eb381b47871bf15dd163c994.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9ca3b1d36d777213eb381b47871bf15dd163c994.hip new file mode 100644 index 000000000000..288085b7f03e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9ca3b1d36d777213eb381b47871bf15dd163c994.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9cc3ef3d3b36f52089548e9dce522b0448e2c26a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9cc3ef3d3b36f52089548e9dce522b0448e2c26a.hip new file mode 100644 index 000000000000..d93a05b6a55e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9cc3ef3d3b36f52089548e9dce522b0448e2c26a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9d3d274058bc0a3d4d35d90669587761fdfbdba1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9d3d274058bc0a3d4d35d90669587761fdfbdba1.hip new file mode 100644 index 000000000000..b154cf2e3a21 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9d3d274058bc0a3d4d35d90669587761fdfbdba1.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, 
kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9d6759d8855c4c6289f1f241a1628cf0406c1b64.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9d6759d8855c4c6289f1f241a1628cf0406c1b64.hip new file mode 100644 index 000000000000..5eb4f26ae7bf --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9d6759d8855c4c6289f1f241a1628cf0406c1b64.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9d69d441f48f9ea346dd8e00376a9a708da3ad87.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9d69d441f48f9ea346dd8e00376a9a708da3ad87.hip new file mode 100644 index 000000000000..13506066037d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9d69d441f48f9ea346dd8e00376a9a708da3ad87.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9dc424f0e192155e3c4e786e5b87d5a1a3e6c4ad.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9dc424f0e192155e3c4e786e5b87d5a1a3e6c4ad.hip new file mode 100644 index 000000000000..da1632ff5dee --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9dc424f0e192155e3c4e786e5b87d5a1a3e6c4ad.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9e51083e13aa4dfa8c969f8f916835a8e5e9ca39.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9e51083e13aa4dfa8c969f8f916835a8e5e9ca39.hip new file mode 100644 index 000000000000..5a974dfe44d6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9e51083e13aa4dfa8c969f8f916835a8e5e9ca39.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9eef1b54d5d3841f3fa6b84cca6c7ad33efa2d9f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9eef1b54d5d3841f3fa6b84cca6c7ad33efa2d9f.hip new file mode 100644 index 000000000000..7217aff87c7f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9eef1b54d5d3841f3fa6b84cca6c7ad33efa2d9f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + 
return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9f0517550c7a23882b95de451e8099ea2186b4ce.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9f0517550c7a23882b95de451e8099ea2186b4ce.hip new file mode 100644 index 000000000000..49834b64050e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9f0517550c7a23882b95de451e8099ea2186b4ce.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9fb389d4b5ba590baa951f17da06f0e53d2bfa55.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9fb389d4b5ba590baa951f17da06f0e53d2bfa55.hip new file mode 100644 index 000000000000..1106b708d9fa --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_9fb389d4b5ba590baa951f17da06f0e53d2bfa55.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a017be7b8bcf303b30a147f41346898acc5fab7d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a017be7b8bcf303b30a147f41346898acc5fab7d.hip new file mode 100644 index 000000000000..29a53b801661 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a017be7b8bcf303b30a147f41346898acc5fab7d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a02a71fdd587e47ee68e0cc76c3c4494ce06c359.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a02a71fdd587e47ee68e0cc76c3c4494ce06c359.hip new file mode 100644 index 000000000000..5c76e72928be --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a02a71fdd587e47ee68e0cc76c3c4494ce06c359.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a02f152e9184af0b3d77082d8bdf519dbbfceb2d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a02f152e9184af0b3d77082d8bdf519dbbfceb2d.hip new file mode 100644 index 000000000000..9f041ce8f562 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a02f152e9184af0b3d77082d8bdf519dbbfceb2d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a046e888e3836b0bd3c49fec8e1872e880798f0c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a046e888e3836b0bd3c49fec8e1872e880798f0c.hip new file mode 100644 index 000000000000..de2940480d73 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a046e888e3836b0bd3c49fec8e1872e880798f0c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a0874fc5ac87a1ec487c7722bf3b1bdaa924ee09.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a0874fc5ac87a1ec487c7722bf3b1bdaa924ee09.hip new file mode 100644 index 000000000000..53aaa3daba99 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a0874fc5ac87a1ec487c7722bf3b1bdaa924ee09.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a094599fb5caf5e7aba728cd4713a8d0c6368a46.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a094599fb5caf5e7aba728cd4713a8d0c6368a46.hip new file mode 100644 index 000000000000..553876b0ac34 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a094599fb5caf5e7aba728cd4713a8d0c6368a46.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a0a556c9358ddd6db719458c81d2d6d822a895da.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a0a556c9358ddd6db719458c81d2d6d822a895da.hip new file mode 100644 index 000000000000..fca384536e56 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a0a556c9358ddd6db719458c81d2d6d822a895da.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a103cd47156a98ad2cf2c325ea00df3f1d67fb72.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a103cd47156a98ad2cf2c325ea00df3f1d67fb72.hip new file mode 100644 index 000000000000..83f54dcbfa2d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a103cd47156a98ad2cf2c325ea00df3f1d67fb72.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a189292c81a18d21a2921ce6740f81ebf4c046ad.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a189292c81a18d21a2921ce6740f81ebf4c046ad.hip new file mode 100644 index 000000000000..e2ead7d8b1e4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a189292c81a18d21a2921ce6740f81ebf4c046ad.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a1c71e7d33f0597fe090a3524e33e18b2e562680.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a1c71e7d33f0597fe090a3524e33e18b2e562680.hip new file mode 100644 index 000000000000..1d2ace1439bd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a1c71e7d33f0597fe090a3524e33e18b2e562680.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a1cba1509c413c870c5d784410855ee1bd737da2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a1cba1509c413c870c5d784410855ee1bd737da2.hip new file mode 100644 index 000000000000..d40fb018bb2f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a1cba1509c413c870c5d784410855ee1bd737da2.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a1d6ad9de7ac7993ae1923a2ef070b7dacb8c563.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a1d6ad9de7ac7993ae1923a2ef070b7dacb8c563.hip new file mode 100644 index 000000000000..19561dc031b1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a1d6ad9de7ac7993ae1923a2ef070b7dacb8c563.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a20c91b2f11bb7e5058ca7935b0bda4f5558a9dc.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a20c91b2f11bb7e5058ca7935b0bda4f5558a9dc.hip new file mode 100644 index 000000000000..6f2730e273fa --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a20c91b2f11bb7e5058ca7935b0bda4f5558a9dc.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a21f3637624762547af1292e1b85e640b1d329dc.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a21f3637624762547af1292e1b85e640b1d329dc.hip new file mode 100644 index 000000000000..2153f8dcc77e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a21f3637624762547af1292e1b85e640b1d329dc.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a225c4f1f3c7b271957768bb9235131c67afb48a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a225c4f1f3c7b271957768bb9235131c67afb48a.hip new file mode 100644 index 000000000000..ade8b04d9eef --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a225c4f1f3c7b271957768bb9235131c67afb48a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a2482a64659c838f3da55f56e3cbbee1dbfe6722.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a2482a64659c838f3da55f56e3cbbee1dbfe6722.hip new file mode 100644 index 000000000000..2114626fba4d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a2482a64659c838f3da55f56e3cbbee1dbfe6722.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a25e2aed617e1ff31f93ae7e054313ee0dceee97.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a25e2aed617e1ff31f93ae7e054313ee0dceee97.hip new file mode 100644 index 000000000000..feda9895fa65 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a25e2aed617e1ff31f93ae7e054313ee0dceee97.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + 
return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a2a715b7e9c1a576f011dfe5769c5b392e984f82.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a2a715b7e9c1a576f011dfe5769c5b392e984f82.hip new file mode 100644 index 000000000000..f940a11d4eb2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a2a715b7e9c1a576f011dfe5769c5b392e984f82.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a2ef5d30a2318ae06430d17f84878800c4ca7364.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a2ef5d30a2318ae06430d17f84878800c4ca7364.hip new file mode 100644 index 000000000000..ab1e7651e61b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a2ef5d30a2318ae06430d17f84878800c4ca7364.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a3339150d8bf9d073827738527f6cbe15b854607.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a3339150d8bf9d073827738527f6cbe15b854607.hip new file mode 100644 index 000000000000..866d9f283151 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a3339150d8bf9d073827738527f6cbe15b854607.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a3709e4fc53d2254a03ea7660b8c72d2f47cf1ad.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a3709e4fc53d2254a03ea7660b8c72d2f47cf1ad.hip new file mode 100644 index 000000000000..cd987b349523 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a3709e4fc53d2254a03ea7660b8c72d2f47cf1ad.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a388a284f45f711d82a6ed87036d87cef1872eb1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a388a284f45f711d82a6ed87036d87cef1872eb1.hip new file mode 100644 index 000000000000..74201ad912ef --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a388a284f45f711d82a6ed87036d87cef1872eb1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a3ac4f93722dc314086f1b7d7b8adc687cd75f82.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a3ac4f93722dc314086f1b7d7b8adc687cd75f82.hip new file mode 100644 index 000000000000..2641e736fc7e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a3ac4f93722dc314086f1b7d7b8adc687cd75f82.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 32, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<32, + ck_tile::fp16_t, + false, + true, + false, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a3d7aa46528ee74e2bef1e87c1feceacfa55e173.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a3d7aa46528ee74e2bef1e87c1feceacfa55e173.hip new file mode 100644 index 000000000000..b78f6fbd44b0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a3d7aa46528ee74e2bef1e87c1feceacfa55e173.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a3dc780b17152f696f9b957432c2eae8fb16e85e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a3dc780b17152f696f9b957432c2eae8fb16e85e.hip new file mode 100644 index 000000000000..7974289e9b6c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a3dc780b17152f696f9b957432c2eae8fb16e85e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a3f9c236d24b30bc9c3fad90cfd6eb00da835de2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a3f9c236d24b30bc9c3fad90cfd6eb00da835de2.hip new file mode 100644 index 000000000000..3eebf99d1a5a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a3f9c236d24b30bc9c3fad90cfd6eb00da835de2.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a3ff8445ba691807caadd9f26e7eb90851875280.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a3ff8445ba691807caadd9f26e7eb90851875280.hip new file mode 100644 index 000000000000..cb73d31ebaf6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a3ff8445ba691807caadd9f26e7eb90851875280.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a421c2ed6b295c458071f1988b9d6f7b46e8992c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a421c2ed6b295c458071f1988b9d6f7b46e8992c.hip new file mode 100644 index 000000000000..f44a100c7dc3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a421c2ed6b295c458071f1988b9d6f7b46e8992c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a4700d87a19a173e84d64e43cffabbed52366e35.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a4700d87a19a173e84d64e43cffabbed52366e35.hip new file mode 100644 index 000000000000..6bd087d8c111 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a4700d87a19a173e84d64e43cffabbed52366e35.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; 
+ +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a487f617c4b84c6a0328fedac750d41dc3dafe27.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a487f617c4b84c6a0328fedac750d41dc3dafe27.hip new file mode 100644 index 000000000000..601f2e7e5c69 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a487f617c4b84c6a0328fedac750d41dc3dafe27.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a48843d844f78690c7a45b730652f0f763c595c7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a48843d844f78690c7a45b730652f0f763c595c7.hip new file mode 100644 index 000000000000..77c01f585a06 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a48843d844f78690c7a45b730652f0f763c595c7.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a4980becb0d3149fee575bad1fc3b463d08aabf5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a4980becb0d3149fee575bad1fc3b463d08aabf5.hip new file mode 100644 index 000000000000..f0add1a5b8bc --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a4980becb0d3149fee575bad1fc3b463d08aabf5.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a4b7f10440331a8a88ff93ba253217c2832bcf9e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a4b7f10440331a8a88ff93ba253217c2832bcf9e.hip new file mode 100644 index 000000000000..e45d80c6a565 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a4b7f10440331a8a88ff93ba253217c2832bcf9e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a55b47aafc4340e69e300ac61a7601a5c14513b7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a55b47aafc4340e69e300ac61a7601a5c14513b7.hip new file mode 100644 index 000000000000..e062a838409a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a55b47aafc4340e69e300ac61a7601a5c14513b7.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a55c7dd576e5b1061c059e5e99aeedf4389e2d25.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a55c7dd576e5b1061c059e5e99aeedf4389e2d25.hip new file mode 100644 index 000000000000..98a8d1282077 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a55c7dd576e5b1061c059e5e99aeedf4389e2d25.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a59423c095db052603d77073d409534bceef425f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a59423c095db052603d77073d409534bceef425f.hip new file mode 100644 index 000000000000..78329df0c5cd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a59423c095db052603d77073d409534bceef425f.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5a7833f4597bb03a3e845d5580d677e97421040.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5a7833f4597bb03a3e845d5580d677e97421040.hip new file mode 100644 index 000000000000..bba679dadf61 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5a7833f4597bb03a3e845d5580d677e97421040.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5bdc110955c05c6c6ea236a6f60266a4a6dce5e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5bdc110955c05c6c6ea236a6f60266a4a6dce5e.hip new file mode 100644 index 000000000000..efde34b2f2ce --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5bdc110955c05c6c6ea236a6f60266a4a6dce5e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5c0109313de1f6245d2a80f8539485b849e9d55.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5c0109313de1f6245d2a80f8539485b849e9d55.hip new file mode 100644 index 000000000000..31f86f982f04 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5c0109313de1f6245d2a80f8539485b849e9d55.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5c4dc0d70c547dbbfb661e879ba7f9adfafc2ea.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5c4dc0d70c547dbbfb661e879ba7f9adfafc2ea.hip new file mode 100644 index 000000000000..71649189b5de --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5c4dc0d70c547dbbfb661e879ba7f9adfafc2ea.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, 
kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5d4eb673bafd81e3a0ee213da4603d88b8460ec.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5d4eb673bafd81e3a0ee213da4603d88b8460ec.hip new file mode 100644 index 000000000000..dd74f1bd6362 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5d4eb673bafd81e3a0ee213da4603d88b8460ec.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if 
(defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5e5cae764142683b70d3344cf07dd1edb7d69e2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5e5cae764142683b70d3344cf07dd1edb7d69e2.hip new file mode 100644 index 000000000000..73465e5777d8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5e5cae764142683b70d3344cf07dd1edb7d69e2.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using 
fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5f2f0cef657ae5e333d65ae4ab20529a43cd7de.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5f2f0cef657ae5e333d65ae4ab20529a43cd7de.hip new file mode 100644 index 000000000000..7b05991667c3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5f2f0cef657ae5e333d65ae4ab20529a43cd7de.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5f8b7b2a891aa9f2ab49762eb31d835efdf18b6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5f8b7b2a891aa9f2ab49762eb31d835efdf18b6.hip new file mode 100644 index 000000000000..e1e538835e19 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5f8b7b2a891aa9f2ab49762eb31d835efdf18b6.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5fa94bb32a80e81886b711ebfcf2df5f5405866.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5fa94bb32a80e81886b711ebfcf2df5f5405866.hip new file mode 100644 index 000000000000..2ab79463c5b4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a5fa94bb32a80e81886b711ebfcf2df5f5405866.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a622fa57764ec746e02f6d4bd4846b48c722b807.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a622fa57764ec746e02f6d4bd4846b48c722b807.hip new file mode 100644 index 000000000000..46e0c94e86b9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a622fa57764ec746e02f6d4bd4846b48c722b807.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + 
constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a62a2ab489839ea1a1bfd1b24e54a3c232ed934f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a62a2ab489839ea1a1bfd1b24e54a3c232ed934f.hip new file mode 100644 index 000000000000..17d1bb9530e1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a62a2ab489839ea1a1bfd1b24e54a3c232ed934f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + 
ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a6461d72fb6ba50e81de3f661528c96dcfdc3f3c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a6461d72fb6ba50e81de3f661528c96dcfdc3f3c.hip new file mode 100644 index 000000000000..44eb1986a666 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a6461d72fb6ba50e81de3f661528c96dcfdc3f3c.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a64b4cf3f6706e4b4e0af4402e2263b9a1585f9b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a64b4cf3f6706e4b4e0af4402e2263b9a1585f9b.hip new file mode 100644 index 000000000000..e793e94adb6d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a64b4cf3f6706e4b4e0af4402e2263b9a1585f9b.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a65c43b870705c780d734f9ef063f55cf8b3b52d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a65c43b870705c780d734f9ef063f55cf8b3b52d.hip new file mode 100644 index 000000000000..3a3b4e69a888 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a65c43b870705c780d734f9ef063f55cf8b3b52d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a673f35edd69241c6b921d6712dfd064d78ecbad.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a673f35edd69241c6b921d6712dfd064d78ecbad.hip new file mode 100644 index 000000000000..bf2296db5008 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a673f35edd69241c6b921d6712dfd064d78ecbad.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + 
s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a71305f191f06cd53b7563971c706e8b71b19e2f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a71305f191f06cd53b7563971c706e8b71b19e2f.hip new file mode 100644 index 000000000000..0642dace3483 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a71305f191f06cd53b7563971c706e8b71b19e2f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = 
fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a74b0e7dd816ad08eec5a1bba6e227afee9813ec.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a74b0e7dd816ad08eec5a1bba6e227afee9813ec.hip new file mode 100644 index 000000000000..4570446797c5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a74b0e7dd816ad08eec5a1bba6e227afee9813ec.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a7784b03ad757d51c234fa86ea9891f055ecd5c1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a7784b03ad757d51c234fa86ea9891f055ecd5c1.hip new file mode 100644 index 000000000000..7ecbad53f903 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a7784b03ad757d51c234fa86ea9891f055ecd5c1.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 32, + true, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<32, + ck_tile::fp16_t, + true, + true, + false, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a78fecb9725ceb4bcf2aa037d43bc43efeb1c3fd.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a78fecb9725ceb4bcf2aa037d43bc43efeb1c3fd.hip new file mode 100644 index 000000000000..09b7dce18845 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a78fecb9725ceb4bcf2aa037d43bc43efeb1c3fd.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a7f7553a7d2f6d42fe695cdc64423c85223af440.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a7f7553a7d2f6d42fe695cdc64423c85223af440.hip new file mode 100644 index 000000000000..e98471e8dc28 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a7f7553a7d2f6d42fe695cdc64423c85223af440.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a821661d8280c6e9d27f2c9ce1b3c855387b5a76.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a821661d8280c6e9d27f2c9ce1b3c855387b5a76.hip new file mode 100644 index 000000000000..5ff5eb5428e1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a821661d8280c6e9d27f2c9ce1b3c855387b5a76.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; 
+ +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a85d35b2fd98742427930eb536e346ffb005edd8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a85d35b2fd98742427930eb536e346ffb005edd8.hip new file mode 100644 index 000000000000..2df4bb295235 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a85d35b2fd98742427930eb536e346ffb005edd8.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a8a4af070ee46d802cb11086b93daf91538f8a04.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a8a4af070ee46d802cb11086b93daf91538f8a04.hip new file mode 100644 index 000000000000..30c74630350b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a8a4af070ee46d802cb11086b93daf91538f8a04.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a8a744edfa3a19d1493611df5bd0d4d59b707d43.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a8a744edfa3a19d1493611df5bd0d4d59b707d43.hip new file mode 100644 index 000000000000..4d052ab3db83 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a8a744edfa3a19d1493611df5bd0d4d59b707d43.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a92b43d374642df991edef1f6036dc898bf77cf8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a92b43d374642df991edef1f6036dc898bf77cf8.hip new file mode 100644 index 000000000000..10985b37d798 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a92b43d374642df991edef1f6036dc898bf77cf8.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
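// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the generated sources): every autogenerated
// file in this directory follows the same pattern -- one translation unit pins
// a single tile/trait configuration and then supplies explicit specializations
// of the generic dispatch entry points declared in the shared fmha header
// (the timed launcher, the "oneshot" launcher, and the name query). The code
// below shows that pattern in miniature with hypothetical stand-in names
// (kernel_traits, run_fmha_bwd, fmha_bwd_name); it is not the real ck_tile or
// fmha_bwd API, just a minimal, compilable sketch of the idiom.
#include <cstdio>
#include <string>

// One "trait" type per generated translation unit (head dim, causal flag, ...).
template <int kHeadDim, bool kIsCausal>
struct kernel_traits {};

// Generic entry points; only explicit specializations are ever defined.
template <typename Traits>
float run_fmha_bwd(float scale);

template <typename Traits>
std::string fmha_bwd_name();

// What a single autogenerated file contributes: specializations binding the
// entry points to one concrete kernel configuration.
template <>
float run_fmha_bwd<kernel_traits<128, true>>(float scale)
{
    std::printf("launching hd128 causal bwd kernel, scale=%f\n", scale);
    return 0.0f; // the generated files return the value from ck_tile::launch_kernel here
}

template <>
std::string fmha_bwd_name<kernel_traits<128, true>>()
{
    return "fmha_bwd_hd128_causal";
}

int main()
{
    // The dispatcher (the shared fmha_bwd header in the real code) picks the
    // trait matching the runtime problem and calls its specialization.
    const float t = run_fmha_bwd<kernel_traits<128, true>>(0.125f);
    std::printf("%s: %f\n", fmha_bwd_name<kernel_traits<128, true>>().c_str(), t);
    return 0;
}
// ---------------------------------------------------------------------------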
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a93324ccf11b273ed20fd960c61df897c8890b1d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a93324ccf11b273ed20fd960c61df897c8890b1d.hip new file mode 100644 index 000000000000..489e0d778e22 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a93324ccf11b273ed20fd960c61df897c8890b1d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a93a03b33305b33055273711ab31a5b8d8298d5d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a93a03b33305b33055273711ab31a5b8d8298d5d.hip new file mode 100644 index 000000000000..3e858ad28353 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a93a03b33305b33055273711ab31a5b8d8298d5d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a968df29f5ae1463706b7981b3bde55918e1aa65.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a968df29f5ae1463706b7981b3bde55918e1aa65.hip new file mode 100644 index 000000000000..c9c1b3813247 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a968df29f5ae1463706b7981b3bde55918e1aa65.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a98925d99dc484da41dd55700e151cf545cf821d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a98925d99dc484da41dd55700e151cf545cf821d.hip new file mode 100644 index 000000000000..82d7bcf2cc1f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a98925d99dc484da41dd55700e151cf545cf821d.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a9b50c6ebb27986ce5b378d8c39315eb9cb91dea.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a9b50c6ebb27986ce5b378d8c39315eb9cb91dea.hip new file mode 100644 index 000000000000..0f7bab218f2b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a9b50c6ebb27986ce5b378d8c39315eb9cb91dea.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a9d2be18e2d53a5144f97dfdebb225fcb6d611d3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a9d2be18e2d53a5144f97dfdebb225fcb6d611d3.hip new file mode 100644 index 000000000000..707470e0d313 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a9d2be18e2d53a5144f97dfdebb225fcb6d611d3.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a9df9ac4ee78e5f4d5bd0567e58a7090907c61e1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a9df9ac4ee78e5f4d5bd0567e58a7090907c61e1.hip new file mode 100644 index 000000000000..99dbd992dfb1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a9df9ac4ee78e5f4d5bd0567e58a7090907c61e1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a9f00f270680de81df7737e848e0408cb070e68b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a9f00f270680de81df7737e848e0408cb070e68b.hip new file mode 100644 index 000000000000..412714e9cc9c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_a9f00f270680de81df7737e848e0408cb070e68b.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 64, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aa1041530f794c7b8dc4a8321ea0fcdd338fff35.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aa1041530f794c7b8dc4a8321ea0fcdd338fff35.hip new file mode 100644 index 000000000000..7297f13a93d5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aa1041530f794c7b8dc4a8321ea0fcdd338fff35.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aa522b43c5e5ea69bcabb4c0fe28def2bd081a12.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aa522b43c5e5ea69bcabb4c0fe28def2bd081a12.hip new file mode 100644 index 000000000000..6a4b0fd85e58 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aa522b43c5e5ea69bcabb4c0fe28def2bd081a12.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aa6d13b09f85ee62bb5018608812181fb43afc86.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aa6d13b09f85ee62bb5018608812181fb43afc86.hip new file mode 100644 index 000000000000..0c54a98de49a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aa6d13b09f85ee62bb5018608812181fb43afc86.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aa82d20635e592edbf00439294835f6f39ad54a3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aa82d20635e592edbf00439294835f6f39ad54a3.hip new file mode 100644 index 000000000000..d3e40eae7c10 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aa82d20635e592edbf00439294835f6f39ad54a3.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aa996b9c843200a2ec33ed4319b48106cd7c6384.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aa996b9c843200a2ec33ed4319b48106cd7c6384.hip new file mode 100644 index 000000000000..046c905dcc3d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aa996b9c843200a2ec33ed4319b48106cd7c6384.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aafe891dad43815e635f81225705ff944f990d75.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aafe891dad43815e635f81225705ff944f990d75.hip new file mode 100644 index 000000000000..a3a3500bf707 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aafe891dad43815e635f81225705ff944f990d75.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab09941bddfa9d61985b55f9b6bf0edec9bb89f6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab09941bddfa9d61985b55f9b6bf0edec9bb89f6.hip new file mode 100644 index 000000000000..e840e2f86c4f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab09941bddfa9d61985b55f9b6bf0edec9bb89f6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab0be5a2072b5e87f5ee58149688796b6513219f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab0be5a2072b5e87f5ee58149688796b6513219f.hip new file mode 100644 index 000000000000..06df0b0f6479 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab0be5a2072b5e87f5ee58149688796b6513219f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab0c3fe9529e24327686070731d0ac3ada76245e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab0c3fe9529e24327686070731d0ac3ada76245e.hip new file mode 100644 index 000000000000..efefcc88dbff --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab0c3fe9529e24327686070731d0ac3ada76245e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab1ca4ce061f7f69a250356f613cab00d1e2ac71.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab1ca4ce061f7f69a250356f613cab00d1e2ac71.hip new file mode 100644 index 000000000000..c36699649bcf --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab1ca4ce061f7f69a250356f613cab00d1e2ac71.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab1d7f93427095e39bfc1d986b3d7fe54073ec75.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab1d7f93427095e39bfc1d986b3d7fe54073ec75.hip new file mode 100644 index 000000000000..908b3eebd844 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab1d7f93427095e39bfc1d986b3d7fe54073ec75.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab43f4a56c166dad0113f51b337a083f4df7cdb6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab43f4a56c166dad0113f51b337a083f4df7cdb6.hip new file mode 100644 index 000000000000..3f5d67fb6bca --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab43f4a56c166dad0113f51b337a083f4df7cdb6.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab56e886d53a1d88fada0f10f00b9f398dc54568.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab56e886d53a1d88fada0f10f00b9f398dc54568.hip new file mode 100644 index 000000000000..10aef5450c16 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab56e886d53a1d88fada0f10f00b9f398dc54568.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab6cd5c9242f8278c8f3d9ce57b97d605c7e5a3e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab6cd5c9242f8278c8f3d9ce57b97d605c7e5a3e.hip new file mode 100644 index 000000000000..b7ff8d40e9a6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab6cd5c9242f8278c8f3d9ce57b97d605c7e5a3e.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 32, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<32, + ck_tile::bf16_t, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab877ae2a1aab04498bf2b26b3fe99d6488ef151.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab877ae2a1aab04498bf2b26b3fe99d6488ef151.hip new file mode 100644 index 000000000000..48a337f18f7d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ab877ae2a1aab04498bf2b26b3fe99d6488ef151.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_abf6c6412f9853855b74a96e862935ddef66f763.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_abf6c6412f9853855b74a96e862935ddef66f763.hip new file mode 100644 index 000000000000..67d289a046a4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_abf6c6412f9853855b74a96e862935ddef66f763.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( 
+ s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_abf92a5314fd33491b5eb6ebd2418b7e0d5db774.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_abf92a5314fd33491b5eb6ebd2418b7e0d5db774.hip new file mode 100644 index 000000000000..4981b113d3f9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_abf92a5314fd33491b5eb6ebd2418b7e0d5db774.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ac1ccde31b47e0e56ee0daab6403fed7895208c7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ac1ccde31b47e0e56ee0daab6403fed7895208c7.hip new file mode 100644 index 000000000000..e4cb72b31f63 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ac1ccde31b47e0e56ee0daab6403fed7895208c7.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ac5e9aee85cd16903bf7b82a4ac10402b0b26e22.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ac5e9aee85cd16903bf7b82a4ac10402b0b26e22.hip new file mode 100644 index 000000000000..d93079f7d8d9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ac5e9aee85cd16903bf7b82a4ac10402b0b26e22.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ac9382cf8bb56ffd962c99329bf67da992f8810d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ac9382cf8bb56ffd962c99329bf67da992f8810d.hip new file mode 100644 index 000000000000..5f0d722e5f0f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ac9382cf8bb56ffd962c99329bf67da992f8810d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aceb0641213e9a45ba48bcf72bb23845720d8b79.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aceb0641213e9a45ba48bcf72bb23845720d8b79.hip new file mode 100644 index 000000000000..5792beb0f34f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aceb0641213e9a45ba48bcf72bb23845720d8b79.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ad091c69d19b27f7ad50ef6311532ad8b642a9c6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ad091c69d19b27f7ad50ef6311532ad8b642a9c6.hip new file mode 100644 index 000000000000..25ecd7ae6e3b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ad091c69d19b27f7ad50ef6311532ad8b642a9c6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ad82071cc074fd30437f6158b5eb2c6df1f8c587.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ad82071cc074fd30437f6158b5eb2c6df1f8c587.hip new file mode 100644 index 000000000000..df6ef5ff27d1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ad82071cc074fd30437f6158b5eb2c6df1f8c587.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ad989d2ce769f20e175fa88f4082c1c25fe03062.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ad989d2ce769f20e175fa88f4082c1c25fe03062.hip new file mode 100644 index 000000000000..b02ddff6364b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ad989d2ce769f20e175fa88f4082c1c25fe03062.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ad9b99a194b59d3149842c15733394da275b12c0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ad9b99a194b59d3149842c15733394da275b12c0.hip new file mode 100644 index 000000000000..ae5eb09a8666 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ad9b99a194b59d3149842c15733394da275b12c0.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ada016be2bd0e377fbe01fa7adb9bbb8febce100.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ada016be2bd0e377fbe01fa7adb9bbb8febce100.hip new file mode 100644 index 000000000000..5e4c5035b79b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ada016be2bd0e377fbe01fa7adb9bbb8febce100.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_adae2d4f8b2dac799e03ea6f279e6ecdf66f5381.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_adae2d4f8b2dac799e03ea6f279e6ecdf66f5381.hip new file mode 100644 index 000000000000..0e59e52fbe85 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_adae2d4f8b2dac799e03ea6f279e6ecdf66f5381.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_adaef10ff2c5d89530310bdf1d53a194f06a94ef.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_adaef10ff2c5d89530310bdf1d53a194f06a94ef.hip new file mode 100644 index 000000000000..4ecc422eb582 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_adaef10ff2c5d89530310bdf1d53a194f06a94ef.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_add29e3e9828911a117dccaa5650e77805730d14.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_add29e3e9828911a117dccaa5650e77805730d14.hip new file mode 100644 index 000000000000..daee5524a5ab --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_add29e3e9828911a117dccaa5650e77805730d14.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_adda7ad787524e3e47dcc1b65c41b2faea38f55f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_adda7ad787524e3e47dcc1b65c41b2faea38f55f.hip new file mode 100644 index 000000000000..14419e0463f4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_adda7ad787524e3e47dcc1b65c41b2faea38f55f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_addb6a14043c5a4df0f5042b3770b40c4e90795c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_addb6a14043c5a4df0f5042b3770b40c4e90795c.hip new file mode 100644 index 000000000000..29702caa8789 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_addb6a14043c5a4df0f5042b3770b40c4e90795c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_adf160741a4f751d2f15d6eb23d4121cdca62b55.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_adf160741a4f751d2f15d6eb23d4121cdca62b55.hip new file mode 100644 index 000000000000..cb0d58a3c82b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_adf160741a4f751d2f15d6eb23d4121cdca62b55.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae1ab1f4bbe86bb9bbc22e4774648076c321136f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae1ab1f4bbe86bb9bbc22e4774648076c321136f.hip new file mode 100644 index 000000000000..4689a93c01b0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae1ab1f4bbe86bb9bbc22e4774648076c321136f.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae1afeb6cfdf860ff08e4c2f11c922fd5bfa621a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae1afeb6cfdf860ff08e4c2f11c922fd5bfa621a.hip new file mode 100644 index 000000000000..97ed99b87aeb --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae1afeb6cfdf860ff08e4c2f11c922fd5bfa621a.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae239476d61f48379754b97f29d7a285cc3192de.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae239476d61f48379754b97f29d7a285cc3192de.hip new file mode 100644 index 000000000000..b8ec64735f70 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae239476d61f48379754b97f29d7a285cc3192de.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
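Each of these autogenerated files contributes exactly one explicit specialization of a launcher template (fmha_fwd_<trait> for the forward kernels, fmha_bwd_dq_dk_dv_<trait> for the backward kernels), keyed by a traits type that encodes head dimension, element type, tiling, pipeline, bias, and dropout choices; a separately maintained dispatcher is then expected to select the matching specialization from the runtime arguments. The following is a minimal, self-contained sketch of that trait-keyed dispatch pattern only. All names in it (run_fwd, dispatch_fwd, the trait tags, the fields of fwd_args) are hypothetical stand-ins for illustration and are not the actual ck_tile or flash-attention API.

// sketch.cpp -- simplified illustration of trait-keyed kernel dispatch (hypothetical names)
#include <iostream>

struct fwd_args { int hdim; bool is_fp16; };

// Primary template is only declared; every definition lives in its own generated file.
template <typename Trait>
float run_fwd(const fwd_args& a);

// Hypothetical trait tags standing in for fmha_fwd_traits_<...> instantiations.
struct trait_hdim64_fp16 {};
struct trait_hdim128_bf16 {};

// "Generated" specializations: one per kernel configuration.
template <>
float run_fwd<trait_hdim64_fp16>(const fwd_args& a) {
    std::cout << "launching hdim" << a.hdim << " fp16 kernel\n";
    return 0.1f; // stand-in for the measured launch time
}

template <>
float run_fwd<trait_hdim128_bf16>(const fwd_args& a) {
    std::cout << "launching hdim" << a.hdim << " bf16 kernel\n";
    return 0.2f;
}

// Hand-written dispatcher: maps runtime arguments onto a compile-time trait.
float dispatch_fwd(const fwd_args& a) {
    if (a.hdim <= 64 && a.is_fp16)   return run_fwd<trait_hdim64_fp16>(a);
    if (a.hdim <= 128 && !a.is_fp16) return run_fwd<trait_hdim128_bf16>(a);
    return -1.0f; // unsupported configuration
}

int main() {
    std::cout << dispatch_fwd({64, true})   << "\n";
    std::cout << dispatch_fwd({128, false}) << "\n";
}

Splitting one specialization per translation unit keeps each generated file small and lets the build parallelize and prune kernel configurations independently; the dispatcher is the only place that has to know about every trait.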
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae4e7253ad4873576052ec0a9400597bb7975753.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae4e7253ad4873576052ec0a9400597bb7975753.hip new file mode 100644 index 000000000000..7bf4ca0ae486 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae4e7253ad4873576052ec0a9400597bb7975753.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae4e80cb185759dd9b3eb3c67c239964b3694caa.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae4e80cb185759dd9b3eb3c67c239964b3694caa.hip new file mode 100644 index 000000000000..8354935ff71c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae4e80cb185759dd9b3eb3c67c239964b3694caa.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae51b30c7e1cd30e550187458350c8db7c59a9ef.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae51b30c7e1cd30e550187458350c8db7c59a9ef.hip new file mode 100644 index 000000000000..b1c013312fdd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae51b30c7e1cd30e550187458350c8db7c59a9ef.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae7899b1ef159ecbf01f27014601eb79b31b49b3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae7899b1ef159ecbf01f27014601eb79b31b49b3.hip new file mode 100644 index 000000000000..60b4fff0c63c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae7899b1ef159ecbf01f27014601eb79b31b49b3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae87b1d5c50606430b544ed650d87df24366e7d5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae87b1d5c50606430b544ed650d87df24366e7d5.hip new file mode 100644 index 000000000000..94e2e5f888a1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae87b1d5c50606430b544ed650d87df24366e7d5.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae8d0bdde763e617beafc0365ec4a3cd11df6c55.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae8d0bdde763e617beafc0365ec4a3cd11df6c55.hip new file mode 100644 index 000000000000..6fafb6b61507 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ae8d0bdde763e617beafc0365ec4a3cd11df6c55.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aebb2441e6cc1ccba4a391566e547402bcf7ced2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aebb2441e6cc1ccba4a391566e547402bcf7ced2.hip new file mode 100644 index 000000000000..df12fb8188b0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aebb2441e6cc1ccba4a391566e547402bcf7ced2.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aebd5fed34ebceb879ae3dffaf58c7c04ab5fe80.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aebd5fed34ebceb879ae3dffaf58c7c04ab5fe80.hip new file mode 100644 index 000000000000..c457e60f89b1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aebd5fed34ebceb879ae3dffaf58c7c04ab5fe80.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aebff7e6605b273bad844b8f70ef031625bff48e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aebff7e6605b273bad844b8f70ef031625bff48e.hip new file mode 100644 index 000000000000..7856926c6916 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aebff7e6605b273bad844b8f70ef031625bff48e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
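Every generated launcher wraps its kernel launch in the same architecture guard: the launch body is compiled only for gfx90a and gfx942, and on any other target the function collapses to a stub that returns 0.0 (or does nothing, for the oneshot variant) so the symbol still links. A minimal sketch of that guard pattern follows; it uses a hypothetical SUPPORTED_GFX_ARCH macro purely for illustration, in place of the real compiler-defined __gfx90a__/__gfx942__ symbols, and does not perform an actual GPU launch.

// guard_sketch.cpp -- simplified illustration of the per-architecture launch guard
#include <iostream>

float launch_guarded_kernel() {
#if defined(SUPPORTED_GFX_ARCH) // stand-in for (defined(__gfx90a__) || defined(__gfx942__))
    // A real kernel launch would go here and return the measured time.
    return 1.0f;
#else
    // Stub path: keeps the specialization defined (and linkable) on unsupported targets.
    return 0.0f;
#endif
}

int main() {
    std::cout << "kernel time: " << launch_guarded_kernel() << "\n";
}

Compiling the sketch with -DSUPPORTED_GFX_ARCH selects the "real" path, mirroring how the generated files only emit a device launch when the HIP compiler targets one of the supported gfx architectures.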
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aec87e65afa93e84d7a947c52f291c1c7360033c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aec87e65afa93e84d7a947c52f291c1c7360033c.hip new file mode 100644 index 000000000000..dc0553254864 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aec87e65afa93e84d7a947c52f291c1c7360033c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aece14f7a220222eb4ce6783ec2b9fce6fde94b8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aece14f7a220222eb4ce6783ec2b9fce6fde94b8.hip new file mode 100644 index 000000000000..e62cc9116e21 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_aece14f7a220222eb4ce6783ec2b9fce6fde94b8.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); 
+#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_af06c0dae15684f83e15722a4c07342af9ea011c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_af06c0dae15684f83e15722a4c07342af9ea011c.hip new file mode 100644 index 000000000000..70a2897b4a08 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_af06c0dae15684f83e15722a4c07342af9ea011c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if 
(defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_af6ccfa11add1ae49888337e84d9c446d2f67da4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_af6ccfa11add1ae49888337e84d9c446d2f67da4.hip new file mode 100644 index 000000000000..f83afc6ed974 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_af6ccfa11add1ae49888337e84d9c446d2f67da4.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using 
fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_afadc4f76e237514db0bc0203102297b79730bd0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_afadc4f76e237514db0bc0203102297b79730bd0.hip new file mode 100644 index 000000000000..1f003f7d4fe3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_afadc4f76e237514db0bc0203102297b79730bd0.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_afc4b47a6fa62a4ca5cff6a7e01c9f6b371d2215.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_afc4b47a6fa62a4ca5cff6a7e01c9f6b371d2215.hip new file mode 100644 index 000000000000..b00cacf498f8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_afc4b47a6fa62a4ca5cff6a7e01c9f6b371d2215.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 32, + true, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<32, + ck_tile::bf16_t, + true, + true, + false, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_afcafd07c1f56e74373ccf37db35976023456d50.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_afcafd07c1f56e74373ccf37db35976023456d50.hip new file mode 100644 index 000000000000..939fe3065517 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_afcafd07c1f56e74373ccf37db35976023456d50.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_afccf699f593c828e11efc053b144044e45b32d6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_afccf699f593c828e11efc053b144044e45b32d6.hip new file mode 100644 index 000000000000..ac8ba2d93193 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_afccf699f593c828e11efc053b144044e45b32d6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_afda8f46b5ded4c2aa9d722fec17b75004b59f7d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_afda8f46b5ded4c2aa9d722fec17b75004b59f7d.hip new file mode 100644 index 000000000000..ec55820cb675 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_afda8f46b5ded4c2aa9d722fec17b75004b59f7d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_afdab954fd111ec48721f25710d61c0c8affd8db.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_afdab954fd111ec48721f25710d61c0c8affd8db.hip new file mode 100644 index 000000000000..622ee71147a1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_afdab954fd111ec48721f25710d61c0c8affd8db.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b00e062055933388e37525df5766f3c14cd3538a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b00e062055933388e37525df5766f3c14cd3538a.hip new file mode 100644 index 000000000000..c722c77e259a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b00e062055933388e37525df5766f3c14cd3538a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b01dc872c24db4db0c9179fc07e17f41060390de.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b01dc872c24db4db0c9179fc07e17f41060390de.hip new file mode 100644 index 000000000000..c176d4066793 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b01dc872c24db4db0c9179fc07e17f41060390de.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b03ab68e33844f97aa58d463e00037bc11c50da0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b03ab68e33844f97aa58d463e00037bc11c50da0.hip new file mode 100644 index 000000000000..60cc872617da --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b03ab68e33844f97aa58d463e00037bc11c50da0.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b04f14f829eff73afaa57a875f74ebd1e6860979.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b04f14f829eff73afaa57a875f74ebd1e6860979.hip new file mode 100644 index 000000000000..12c685195aae --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b04f14f829eff73afaa57a875f74ebd1e6860979.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b0544a38dfdf4d81dc95894387845f48435e299a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b0544a38dfdf4d81dc95894387845f48435e299a.hip new file mode 100644 index 000000000000..77a4df5b6528 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b0544a38dfdf4d81dc95894387845f48435e299a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b0dd965d5d9080ed5c6a04b7eea9890f3a264f20.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b0dd965d5d9080ed5c6a04b7eea9890f3a264f20.hip new file mode 100644 index 000000000000..916f9471868d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b0dd965d5d9080ed5c6a04b7eea9890f3a264f20.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b0f555b74ed36f1bef8f47880b3edc6760f27788.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b0f555b74ed36f1bef8f47880b3edc6760f27788.hip new file mode 100644 index 000000000000..f7ff6b9b35f7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b0f555b74ed36f1bef8f47880b3edc6760f27788.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b1766695dbb790bd614b83dc7569ad449404cc89.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b1766695dbb790bd614b83dc7569ad449404cc89.hip new file mode 100644 index 000000000000..1f320e33ac50 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b1766695dbb790bd614b83dc7569ad449404cc89.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b18a615e66d7cd739ce35412811359a03cb23a8e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b18a615e66d7cd739ce35412811359a03cb23a8e.hip new file mode 100644 index 000000000000..1651f3cb839e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b18a615e66d7cd739ce35412811359a03cb23a8e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b192c55f002d8540d5f965cc4df0c2e33f4b9ff9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b192c55f002d8540d5f965cc4df0c2e33f4b9ff9.hip new file mode 100644 index 000000000000..dbe426a18306 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b192c55f002d8540d5f965cc4df0c2e33f4b9ff9.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 64, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, true, false>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b19f05f6848403480ba41d37cdbf44ccca1b1f8d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b19f05f6848403480ba41d37cdbf44ccca1b1f8d.hip new file mode 100644 index 000000000000..2cac2fc3b4e4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b19f05f6848403480ba41d37cdbf44ccca1b1f8d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b1ad101ce91348266d3885afdf2996a0fdb72135.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b1ad101ce91348266d3885afdf2996a0fdb72135.hip new file mode 100644 index 000000000000..05d80da75242 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b1ad101ce91348266d3885afdf2996a0fdb72135.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b1c5d55d47d6038e9162d32ac968ff58c0942938.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b1c5d55d47d6038e9162d32ac968ff58c0942938.hip new file mode 100644 index 000000000000..bd5640a71a9f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b1c5d55d47d6038e9162d32ac968ff58c0942938.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b20c6252863a73341b0010191fad4c834860f884.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b20c6252863a73341b0010191fad4c834860f884.hip new file mode 100644 index 000000000000..5aefb14dcc0f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b20c6252863a73341b0010191fad4c834860f884.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b20e314642cf565e4f32bceffdb5c0e653ab627b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b20e314642cf565e4f32bceffdb5c0e653ab627b.hip new file mode 100644 index 000000000000..a5ccdcb12c63 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b20e314642cf565e4f32bceffdb5c0e653ab627b.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 
0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b24f91dec2029b25d0d96962528410df55a468ed.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b24f91dec2029b25d0d96962528410df55a468ed.hip new file mode 100644 index 000000000000..dc448a1795d7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b24f91dec2029b25d0d96962528410df55a468ed.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if 
(defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b285e2f1970b78e18002464eeda63798229bbc3a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b285e2f1970b78e18002464eeda63798229bbc3a.hip new file mode 100644 index 000000000000..6915e2071af9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b285e2f1970b78e18002464eeda63798229bbc3a.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b298e213f927b518c693660110f08bdd94990ef0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b298e213f927b518c693660110f08bdd94990ef0.hip new file mode 100644 index 000000000000..76d83307c7ff --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b298e213f927b518c693660110f08bdd94990ef0.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr 
ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b2af5f5b5ee3ae964824a3e9c7bbeb5bb39c557c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b2af5f5b5ee3ae964824a3e9c7bbeb5bb39c557c.hip new file mode 100644 index 000000000000..708b4cee4924 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b2af5f5b5ee3ae964824a3e9c7bbeb5bb39c557c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename 
FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b2f91e937b427ecc932c0cb0c90b2c2378db0be6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b2f91e937b427ecc932c0cb0c90b2c2378db0be6.hip new file mode 100644 index 000000000000..e70029e13638 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b2f91e937b427ecc932c0cb0c90b2c2378db0be6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
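Each of the autogenerated backward files above follows the same shape: it pins one configuration (head dimension, data type, pipeline variant, mask, dropout, bias and padding flags) into a dq_dk_dv_trait_0 alias and then provides explicit specializations of fmha_bwd_dq_dk_dv_, fmha_bwd_dq_dk_dv_oneshot_ and fmha_bwd_dq_dk_dv_get_name_ for exactly that trait. The standalone sketch below illustrates the trait-keyed explicit-specialization pattern in plain C++; the names bwd_trait and launch_bwd_kernel are hypothetical stand-ins for illustration only, not the ck_tile API.

#include <iostream>

// A "trait" type encodes one compile-time configuration
// (head dimension, data-type tag, causal-mask flag, ...).
template <int kHeadDim, typename DataType, bool kIsCausal>
struct bwd_trait {};

struct bf16_tag {};
struct fp16_tag {};

// The primary template is only declared; every supported configuration must
// provide an explicit specialization (one per generated translation unit).
template <typename Trait>
float launch_bwd_kernel(float scale);

// One generated .hip file contributes exactly one specialization like this,
// with the real body guarded by the gfx90a/gfx942 architecture check.
template <>
float launch_bwd_kernel<bwd_trait<32, bf16_tag, true>>(float scale)
{
    std::cout << "hd=32 bf16 causal backward kernel\n";
    return 0.1f * scale; // stand-in for the measured kernel time in ms
}

int main()
{
    // A dispatcher maps runtime parameters (head dim, dtype, flags) to the
    // matching trait and calls the corresponding specialization.
    float ms = launch_bwd_kernel<bwd_trait<32, bf16_tag, true>>(1.0f);
    std::cout << "elapsed: " << ms << " ms\n";
    return 0;
}

One practical consequence of this layout is that supporting a new configuration only means generating one more translation unit with one more specialization; nothing in the shared dispatch code has to change.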
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b3063d06723ac70c5f8802ab49c5c35e1debf56e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b3063d06723ac70c5f8802ab49c5c35e1debf56e.hip new file mode 100644 index 000000000000..0d0e15072ae1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b3063d06723ac70c5f8802ab49c5c35e1debf56e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b31f56244076c501cb09b4b90975132cae4c4386.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b31f56244076c501cb09b4b90975132cae4c4386.hip new file mode 100644 index 000000000000..1148242e7bd4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b31f56244076c501cb09b4b90975132cae4c4386.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b3486244e0b7d6dbcaa1951e8b8883ce441c3f99.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b3486244e0b7d6dbcaa1951e8b8883ce441c3f99.hip new file mode 100644 index 000000000000..4224bf6dcfa6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b3486244e0b7d6dbcaa1951e8b8883ce441c3f99.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
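A detail shared by every generated launcher above: the ck_tile::launch_kernel / make_kernel call is wrapped in #if (defined(__gfx90a__) || defined(__gfx942__)), so the device body is only instantiated when compiling for those targets; on any other architecture the launcher degrades to returning 0.0 (or to a no-op for the oneshot_ variant) and the translation unit still builds and links. A minimal standalone sketch of that guard follows; launch_guarded_kernel is a hypothetical helper standing in for the real launch, not part of ck_tile.

#include <iostream>

// Stand-in for one generated launcher; returns elapsed milliseconds.
// __gfx90a__ / __gfx942__ are only defined by the HIP device compiler for
// those targets, so a host or non-matching device build takes the #else path.
static float launch_guarded_kernel()
{
#if defined(__gfx90a__) || defined(__gfx942__)
    // Real code wraps ck_tile::launch_kernel(...) around ck_tile::make_kernel(...)
    // and returns the measured time for the launched kernel.
    return 0.05f;
#else
    // Unsupported architecture: report zero time instead of failing to build.
    return 0.0f;
#endif
}

int main()
{
    std::cout << "kernel time: " << launch_guarded_kernel() << " ms\n";
    return 0;
}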
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b34c1ce348c3d9cdf6bbec9758de9d5fe94c43fc.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b34c1ce348c3d9cdf6bbec9758de9d5fe94c43fc.hip new file mode 100644 index 000000000000..f9007c8d0a30 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b34c1ce348c3d9cdf6bbec9758de9d5fe94c43fc.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b38a1d3cffae01332a3a9d9472ff1b2c443e82af.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b38a1d3cffae01332a3a9d9472ff1b2c443e82af.hip new file mode 100644 index 000000000000..3af4858be13e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b38a1d3cffae01332a3a9d9472ff1b2c443e82af.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 
0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b3a104733f678193068d8642d6560faa03897258.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b3a104733f678193068d8642d6560faa03897258.hip new file mode 100644 index 000000000000..0412b0d6b1a9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b3a104733f678193068d8642d6560faa03897258.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b3da22d3482738a8474ae15e8e5fca9020c4e195.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b3da22d3482738a8474ae15e8e5fca9020c4e195.hip new file mode 100644 index 000000000000..9c8c668bd2b7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b3da22d3482738a8474ae15e8e5fca9020c4e195.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b41735d250b5a16967281a5f07873b9cde3df4d6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b41735d250b5a16967281a5f07873b9cde3df4d6.hip new file mode 100644 index 000000000000..e5e19dc68d87 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b41735d250b5a16967281a5f07873b9cde3df4d6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b41a30092e8138877c1f6c25656e0f8ae2c2444e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b41a30092e8138877c1f6c25656e0f8ae2c2444e.hip new file mode 100644 index 000000000000..50a898f1adb2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b41a30092e8138877c1f6c25656e0f8ae2c2444e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = 
ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b41ea5293bc1c56efa2c4b5681d965aa6f2ce6c3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b41ea5293bc1c56efa2c4b5681d965aa6f2ce6c3.hip new file mode 100644 index 000000000000..535f441e9d43 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b41ea5293bc1c56efa2c4b5681d965aa6f2ce6c3.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b4588379eaa268d79fe8f8e4457b009f204a5fb7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b4588379eaa268d79fe8f8e4457b009f204a5fb7.hip new file mode 100644 index 000000000000..568fb2a61a68 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b4588379eaa268d79fe8f8e4457b009f204a5fb7.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 128, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, true, false>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b493c99888d82cd2852bfb101f99a2e6a27665b8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b493c99888d82cd2852bfb101f99a2e6a27665b8.hip new file mode 100644 index 000000000000..b1a203138262 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b493c99888d82cd2852bfb101f99a2e6a27665b8.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b4a5715b550f67b8870ba66e1e6282a26cc1dbf3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b4a5715b550f67b8870ba66e1e6282a26cc1dbf3.hip new file mode 100644 index 000000000000..46e3bb80c18b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b4a5715b550f67b8870ba66e1e6282a26cc1dbf3.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b4b037a2e262d11d3ed7d9feeb41b9e05427a739.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b4b037a2e262d11d3ed7d9feeb41b9e05427a739.hip new file mode 100644 index 000000000000..82f28e19893a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b4b037a2e262d11d3ed7d9feeb41b9e05427a739.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b4bd2d206ceb237ed2c51f58abb5cbf96e39d07b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b4bd2d206ceb237ed2c51f58abb5cbf96e39d07b.hip new file mode 100644 index 000000000000..e2273c31b333 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b4bd2d206ceb237ed2c51f58abb5cbf96e39d07b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b4ec377c44ac18527ca6a01bc3b146706a6e1e09.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b4ec377c44ac18527ca6a01bc3b146706a6e1e09.hip new file mode 100644 index 000000000000..44950e6110f0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b4ec377c44ac18527ca6a01bc3b146706a6e1e09.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b4f12f10d7b968e0d8e7c23f36d3a360de74a905.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b4f12f10d7b968e0d8e7c23f36d3a360de74a905.hip new file mode 100644 index 000000000000..923f41381c9c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b4f12f10d7b968e0d8e7c23f36d3a360de74a905.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b50e6df20a2426abd3d2ff2262a37c009196024c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b50e6df20a2426abd3d2ff2262a37c009196024c.hip new file mode 100644 index 000000000000..9a670ef26e1a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b50e6df20a2426abd3d2ff2262a37c009196024c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
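Editorial note between hunks: a few of the preceding files (e.g. the 71-line `fmha_bwd_dot_do_o_` unit) instantiate ck_tile's `BlockFmhaBwdOGradDotO` preprocessing kernel, which computes the per-query-row dot product D = rowsum(dO ∘ O) consumed by the dq/dk/dv backward kernel. The sketch below is a host-side reference of that quantity under the assumption of a single head with row-major [seqlen, hdim] layout; it is an illustration of what the generated kernel produces, not the kernel itself.

#include <cstddef>
#include <vector>

// Reference: D[i] = sum_j dO[i][j] * O[i][j], where O is the attention output
// and dO its incoming gradient. The generated FmhaBwdOGradDotOKernel computes
// the same per-row values on device before the main dq/dk/dv pass runs.
std::vector<float> dot_do_o_reference(const std::vector<float>& o,
                                      const std::vector<float>& do_grad,
                                      std::size_t seqlen, std::size_t hdim)
{
    std::vector<float> d(seqlen, 0.0f);
    for (std::size_t i = 0; i < seqlen; ++i)
        for (std::size_t j = 0; j < hdim; ++j)
            d[i] += do_grad[i * hdim + j] * o[i * hdim + j];
    return d;
}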
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b513834918d5ea789e2db21abece7c2d3532a7e7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b513834918d5ea789e2db21abece7c2d3532a7e7.hip new file mode 100644 index 000000000000..32b5404e47e9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b513834918d5ea789e2db21abece7c2d3532a7e7.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b5248f443a12d96815c04409a00102923c717023.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b5248f443a12d96815c04409a00102923c717023.hip new file mode 100644 index 000000000000..3998da16612c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b5248f443a12d96815c04409a00102923c717023.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, 
kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b5371415448fffffd58bf014dac9f4876153657b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b5371415448fffffd58bf014dac9f4876153657b.hip new file mode 100644 index 000000000000..957890d6bceb --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b5371415448fffffd58bf014dac9f4876153657b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if 
(defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b5ac596c636df55e81293228cbc53dcbb3024e5a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b5ac596c636df55e81293228cbc53dcbb3024e5a.hip new file mode 100644 index 000000000000..0d423090af9f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b5ac596c636df55e81293228cbc53dcbb3024e5a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using 
fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b5ba2e73df35f6e0f7317303823fde92a42b1a35.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b5ba2e73df35f6e0f7317303823fde92a42b1a35.hip new file mode 100644 index 000000000000..c027189593b0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b5ba2e73df35f6e0f7317303823fde92a42b1a35.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b5bccc85f74f54a2ceb17fe3040b04fe306c53f9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b5bccc85f74f54a2ceb17fe3040b04fe306c53f9.hip new file mode 100644 index 000000000000..a8c37b185861 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b5bccc85f74f54a2ceb17fe3040b04fe306c53f9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b5c3131fb8e5a25bd4a14bc9075eb6fa01b61d02.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b5c3131fb8e5a25bd4a14bc9075eb6fa01b61d02.hip new file mode 100644 index 000000000000..ea9eb74433cf --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b5c3131fb8e5a25bd4a14bc9075eb6fa01b61d02.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); 
+#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b5c7fca1f76a31b0390e92d90d569fab94d4f783.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b5c7fca1f76a31b0390e92d90d569fab94d4f783.hip new file mode 100644 index 000000000000..0c2d7bb88b24 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b5c7fca1f76a31b0390e92d90d569fab94d4f783.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b5db3d5b1d8af89381fc4b8073f84c5fa25fdef5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b5db3d5b1d8af89381fc4b8073f84c5fa25fdef5.hip new file mode 100644 index 000000000000..1eb209d714a2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b5db3d5b1d8af89381fc4b8073f84c5fa25fdef5.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 32, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<32, + ck_tile::fp16_t, + false, + false, + true, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b60a4e87a7aabfe3c1ce02b408522f3ec862e3d7.hip 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b60a4e87a7aabfe3c1ce02b408522f3ec862e3d7.hip new file mode 100644 index 000000000000..19c0a9cc05bc --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b60a4e87a7aabfe3c1ce02b408522f3ec862e3d7.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template 
<> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b6b17ae67adee9e56a022cd2a5514fb9c4e99920.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b6b17ae67adee9e56a022cd2a5514fb9c4e99920.hip new file mode 100644 index 000000000000..1671e0dd8936 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b6b17ae67adee9e56a022cd2a5514fb9c4e99920.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + 
fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b72a804bb3c99830653d41ac0bd49943c801b89a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b72a804bb3c99830653d41ac0bd49943c801b89a.hip new file mode 100644 index 000000000000..ad7cb8896663 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b72a804bb3c99830653d41ac0bd49943c801b89a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b737410b404a51043fc3bd503c0b107c297e4c9f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b737410b404a51043fc3bd503c0b107c297e4c9f.hip new file mode 100644 index 000000000000..10f202afe22d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b737410b404a51043fc3bd503c0b107c297e4c9f.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b75843bb13058ffe29251e053800c509c7590544.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b75843bb13058ffe29251e053800c509c7590544.hip new file mode 100644 index 000000000000..6d508d109ce7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b75843bb13058ffe29251e053800c509c7590544.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b774450ebadaacf23e944aaf8ca90eada01e8a5a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b774450ebadaacf23e944aaf8ca90eada01e8a5a.hip new file mode 100644 index 000000000000..c36d88e31245 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b774450ebadaacf23e944aaf8ca90eada01e8a5a.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b779cc0b0380e1e6a2b51fc6216fdd72215b882b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b779cc0b0380e1e6a2b51fc6216fdd72215b882b.hip new file mode 100644 index 000000000000..e27e532203df --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b779cc0b0380e1e6a2b51fc6216fdd72215b882b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b7a03ab0b7887cc7ed0cb40e56360a8d36c0bb8e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b7a03ab0b7887cc7ed0cb40e56360a8d36c0bb8e.hip new file mode 100644 index 000000000000..6488aeb1bc30 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b7a03ab0b7887cc7ed0cb40e56360a8d36c0bb8e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b80d0828ba6d24ea3c1a97bd9835ee937b4b32fb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b80d0828ba6d24ea3c1a97bd9835ee937b4b32fb.hip new file mode 100644 index 000000000000..62623e588e4a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b80d0828ba6d24ea3c1a97bd9835ee937b4b32fb.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = 
ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b872f9e6ebe330cc1818ea82b53acec79a2f672c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b872f9e6ebe330cc1818ea82b53acec79a2f672c.hip new file mode 100644 index 000000000000..ea3a4fab3919 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b872f9e6ebe330cc1818ea82b53acec79a2f672c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b8fbc6f6e9c515edce3c7a438b3bc308b30d3857.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b8fbc6f6e9c515edce3c7a438b3bc308b30d3857.hip new file mode 100644 index 000000000000..d826af150427 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b8fbc6f6e9c515edce3c7a438b3bc308b30d3857.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 128, + true, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128, + ck_tile::fp16_t, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b9385db12001110c42eff6aabad935a69ad3afe2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b9385db12001110c42eff6aabad935a69ad3afe2.hip new file mode 100644 index 000000000000..288c0b2167da --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b9385db12001110c42eff6aabad935a69ad3afe2.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b9559dd36a0a4f5e068a722e285f485137bd5ef0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b9559dd36a0a4f5e068a722e285f485137bd5ef0.hip new file mode 100644 index 000000000000..26ec43c91b0a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b9559dd36a0a4f5e068a722e285f485137bd5ef0.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b9627f9c8d0088df0364a64643f2b5dcd951f2bb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b9627f9c8d0088df0364a64643f2b5dcd951f2bb.hip new file mode 100644 index 000000000000..0e201d264535 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b9627f9c8d0088df0364a64643f2b5dcd951f2bb.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); 
+#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b9a742ceeb6736a2c8f9439d0b05e10d3e0c5c6f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b9a742ceeb6736a2c8f9439d0b05e10d3e0c5c6f.hip new file mode 100644 index 000000000000..ed60152f8ac6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b9a742ceeb6736a2c8f9439d0b05e10d3e0c5c6f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b9baf70220079e6d4e87eb01a7259923d8a01e29.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b9baf70220079e6d4e87eb01a7259923d8a01e29.hip new file mode 100644 index 000000000000..55e65309e5ed --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b9baf70220079e6d4e87eb01a7259923d8a01e29.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b9d00ab8373747a5c6b9d2f8dd50ceb14db4163c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b9d00ab8373747a5c6b9d2f8dd50ceb14db4163c.hip new file mode 100644 index 000000000000..8f6303a1c6c3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b9d00ab8373747a5c6b9d2f8dd50ceb14db4163c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
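Each generated backward stub above follows the same shape: a dq_dk_dv_trait_0 alias encodes the head dimension, element type, pipeline enum, mask, dropout and bias configuration, and the file then supplies the explicit specializations (fmha_bwd_dq_dk_dv_, its oneshot variant, and the get_name_ helper) that the runtime dispatcher resolves for exactly that trait. The following is a minimal, standalone sketch of that trait-keyed specialization idea; KernelTrait and run_kernel are simplified stand-in names, not the real ck_tile or fmha_bwd interfaces:

    // Minimal illustration of trait-keyed dispatch via explicit template
    // specialization (stand-in names; not the actual generated interface).
    #include <iostream>

    // A "trait" type encodes one kernel configuration at compile time.
    template <int HeadDim, bool IsDeterministic>
    struct KernelTrait {};

    // The primary template is declared but never defined, so only
    // configurations that have a generated specialization can link.
    template <typename Trait>
    float run_kernel(float scale);

    // One generated translation unit contributes one specialization.
    template <>
    float run_kernel<KernelTrait<64, false>>(float scale)
    {
        std::cout << "hdim64 kernel, scale=" << scale << "\n";
        return 0.0f; // the real stubs return a measured time here
    }

    int main()
    {
        // The dispatcher selects the specialization purely from the trait type.
        return static_cast<int>(run_kernel<KernelTrait<64, false>>(0.125f));
    }

Splitting every configuration into its own translation unit keeps each generated .hip file small enough to compile in parallel, while the trait type keeps the dispatch type-safe.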
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b9ed0a64deb55616646ea98b21a891c971cd98ad.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b9ed0a64deb55616646ea98b21a891c971cd98ad.hip new file mode 100644 index 000000000000..e313cb363ef9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_b9ed0a64deb55616646ea98b21a891c971cd98ad.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ba145535e53899fe127987aa854f81234a9c51c4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ba145535e53899fe127987aa854f81234a9c51c4.hip new file mode 100644 index 000000000000..0a7503140638 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ba145535e53899fe127987aa854f81234a9c51c4.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ba8b09f0aaa40a7c9ad5f0458b460d3e328f3c74.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ba8b09f0aaa40a7c9ad5f0458b460d3e328f3c74.hip new file mode 100644 index 000000000000..32a61c6a29ca --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ba8b09f0aaa40a7c9ad5f0458b460d3e328f3c74.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bafbef3f13d429ec3e9f4672218998d5669d79f2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bafbef3f13d429ec3e9f4672218998d5669d79f2.hip new file mode 100644 index 000000000000..e102afa7c305 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bafbef3f13d429ec3e9f4672218998d5669d79f2.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bb111b7acc269f8d5e70915d3efde4c425aa5f5c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bb111b7acc269f8d5e70915d3efde4c425aa5f5c.hip new file mode 100644 index 000000000000..697f8d6f393d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bb111b7acc269f8d5e70915d3efde4c425aa5f5c.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bb28a4e95723e3df380f98b5ac107c4df353850b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bb28a4e95723e3df380f98b5ac107c4df353850b.hip new file mode 100644 index 000000000000..0c0d4e1400a7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bb28a4e95723e3df380f98b5ac107c4df353850b.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bb35c86443cc9ea38c06ebc0656306483c95ef67.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bb35c86443cc9ea38c06ebc0656306483c95ef67.hip new file mode 100644 index 000000000000..b9c554b89b12 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bb35c86443cc9ea38c06ebc0656306483c95ef67.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
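Each of the generated .hip files in this directory follows the same pattern: define one kernel configuration, define one trait tag, then provide the explicit specialization of the launcher template for that tag, guarded so a real launch is only compiled for gfx90a/gfx942 and a zero elapsed time is returned otherwise. Below is a minimal, self-contained C++ sketch of that explicit-specialization dispatch idiom; the names used here (KernelTraits, run_fmha_kernel) are illustrative placeholders and are not part of the ck_tile API.

#include <cstdio>

// Trait tag: encodes the compile-time configuration (head dim, causal flag);
// each generated translation unit would define exactly one such tag.
template <int HeadDim, bool kIsCausal>
struct KernelTraits {};

// Launcher primary template: declared once in a shared header, never defined.
template <typename Traits>
float run_fmha_kernel(float* out, int n);

// One "generated" specialization per configuration. The guard stands in for
// the __gfx90a__/__gfx942__ checks used by the real files (collapsed to a
// constant here so the sketch compiles anywhere).
template <>
float run_fmha_kernel<KernelTraits<128, true>>(float* out, int n)
{
#if 1  // stand-in for the gfx90a/gfx942 architecture guard
    for (int i = 0; i < n; ++i) out[i] = 0.0f;  // placeholder for the kernel launch
    return 1.0f;                                // placeholder for elapsed time
#else
    return 0.0f;  // unsupported arch: report zero time, launch nothing
#endif
}

int main()
{
    float buf[4];
    std::printf("%f\n", run_fmha_kernel<KernelTraits<128, true>>(buf, 4));
    return 0;
}

Because the primary template is only declared, a configuration that was never generated fails at link time instead of silently running a mismatched kernel, which is consistent with each supported dtype/head-dim/pipeline combination getting its own translation unit in this patch.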
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bba10ecb79ede07324e1198a71a95ff26e9eb235.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bba10ecb79ede07324e1198a71a95ff26e9eb235.hip new file mode 100644 index 000000000000..ea4a38d7df17 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bba10ecb79ede07324e1198a71a95ff26e9eb235.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bbe23201fbebed25781f249e5c77c31e0e7f9ddb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bbe23201fbebed25781f249e5c77c31e0e7f9ddb.hip new file mode 100644 index 000000000000..5f464dfd0cae --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bbe23201fbebed25781f249e5c77c31e0e7f9ddb.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); 
+#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bbfd025488e52b97c04995c4c5faff371b77e4d6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bbfd025488e52b97c04995c4c5faff371b77e4d6.hip new file mode 100644 index 000000000000..b2ea0970aa67 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bbfd025488e52b97c04995c4c5faff371b77e4d6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bc1ae1dddb8cc5d78196da6b26ebe66c1ce7e567.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bc1ae1dddb8cc5d78196da6b26ebe66c1ce7e567.hip new file mode 100644 index 000000000000..96a10e083801 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bc1ae1dddb8cc5d78196da6b26ebe66c1ce7e567.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bc238fd2095b26a167b41cdec8280182330b7b25.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bc238fd2095b26a167b41cdec8280182330b7b25.hip new file mode 100644 index 000000000000..8db90faeefc7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bc238fd2095b26a167b41cdec8280182330b7b25.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bc4425e30a0b17e8b31726817e8d3177b5c51934.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bc4425e30a0b17e8b31726817e8d3177b5c51934.hip new file mode 100644 index 000000000000..b0da0f983572 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bc4425e30a0b17e8b31726817e8d3177b5c51934.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bc4e0f0496a34d2fb43c80ce0162ad4183f29064.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bc4e0f0496a34d2fb43c80ce0162ad4183f29064.hip new file mode 100644 index 000000000000..cde5025b2042 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bc4e0f0496a34d2fb43c80ce0162ad4183f29064.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bc6ce17223d8d83a64b8c96ac88223e4441a4692.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bc6ce17223d8d83a64b8c96ac88223e4441a4692.hip new file mode 100644 index 000000000000..f52451423cd8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bc6ce17223d8d83a64b8c96ac88223e4441a4692.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bc744db85d4237ee9640f1658e0caab7648e3bb6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bc744db85d4237ee9640f1658e0caab7648e3bb6.hip new file mode 100644 index 000000000000..c75edd4b6482 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bc744db85d4237ee9640f1658e0caab7648e3bb6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bc79e255d25744725e2a9db9f90d5cc2b8a0e0c1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bc79e255d25744725e2a9db9f90d5cc2b8a0e0c1.hip new file mode 100644 index 000000000000..c96d54733ac6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bc79e255d25744725e2a9db9f90d5cc2b8a0e0c1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bc897852a4ca992961843144f4ec4f8b86dd5e9d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bc897852a4ca992961843144f4ec4f8b86dd5e9d.hip new file mode 100644 index 000000000000..a10f2dcd0b0d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bc897852a4ca992961843144f4ec4f8b86dd5e9d.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bcb6f0730fd09b4c6c60913425927dfdb8f83d82.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bcb6f0730fd09b4c6c60913425927dfdb8f83d82.hip new file mode 100644 index 000000000000..ffb2ee417869 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bcb6f0730fd09b4c6c60913425927dfdb8f83d82.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bcd7ccdceb7baf3b986f2a0248827822a5f72e47.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bcd7ccdceb7baf3b986f2a0248827822a5f72e47.hip new file mode 100644 index 000000000000..61f98c5dbada --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bcd7ccdceb7baf3b986f2a0248827822a5f72e47.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bcf8836c8cf932cc2748e313885003f0e11a887f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bcf8836c8cf932cc2748e313885003f0e11a887f.hip new file mode 100644 index 000000000000..acdbebfd5c1c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bcf8836c8cf932cc2748e313885003f0e11a887f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bd064e302ff5b983dbdb4ccf51383fb29ddff44f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bd064e302ff5b983dbdb4ccf51383fb29ddff44f.hip new file mode 100644 index 000000000000..48d33029a3bd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bd064e302ff5b983dbdb4ccf51383fb29ddff44f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bd28203f47b6a48e9b66302cf8312f3796ca500c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bd28203f47b6a48e9b66302cf8312f3796ca500c.hip new file mode 100644 index 000000000000..bd8ca3a01d5f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bd28203f47b6a48e9b66302cf8312f3796ca500c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bd37f4f7914805a97d5073f1ebf8a8b8c2648d31.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bd37f4f7914805a97d5073f1ebf8a8b8c2648d31.hip new file mode 100644 index 000000000000..ea4b4b24aabc --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bd37f4f7914805a97d5073f1ebf8a8b8c2648d31.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bd3daa5f99b4522d932334924347353ce2854821.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bd3daa5f99b4522d932334924347353ce2854821.hip new file mode 100644 index 000000000000..d7575f04ec46 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bd3daa5f99b4522d932334924347353ce2854821.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bd6aa39d0ae3c87d011610cdb5e2e317f337c454.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bd6aa39d0ae3c87d011610cdb5e2e317f337c454.hip new file mode 100644 index 000000000000..6a2de30bb01f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bd6aa39d0ae3c87d011610cdb5e2e317f337c454.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bd80a1774d8b7d8bee4e8663392b97cda11dcbf5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bd80a1774d8b7d8bee4e8663392b97cda11dcbf5.hip new file mode 100644 index 000000000000..aeb274f09f1f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bd80a1774d8b7d8bee4e8663392b97cda11dcbf5.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bd8bf7c572c1984ca3061062cf3c31d993f6762d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bd8bf7c572c1984ca3061062cf3c31d993f6762d.hip new file mode 100644 index 000000000000..f32a6421d835 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bd8bf7c572c1984ca3061062cf3c31d993f6762d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bd9c47f3305e47db6ab6bc627fb3d80269633074.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bd9c47f3305e47db6ab6bc627fb3d80269633074.hip new file mode 100644 index 000000000000..f941e72f6171 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bd9c47f3305e47db6ab6bc627fb3d80269633074.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bdab172627718278a71a93e3737ef08ad9259a4f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bdab172627718278a71a93e3737ef08ad9259a4f.hip new file mode 100644 index 000000000000..64bacecc2cd8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bdab172627718278a71a93e3737ef08ad9259a4f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bde24a8dbe6add6f2dd2beb48b1280f3a84a9b2a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bde24a8dbe6add6f2dd2beb48b1280f3a84a9b2a.hip new file mode 100644 index 000000000000..1dcb86091fbe --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bde24a8dbe6add6f2dd2beb48b1280f3a84a9b2a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_be1e1533fc37b41838bd37edc2b6d2f2e76ae1c6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_be1e1533fc37b41838bd37edc2b6d2f2e76ae1c6.hip new file mode 100644 index 000000000000..59c5142497f3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_be1e1533fc37b41838bd37edc2b6d2f2e76ae1c6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_be4dd90ccb2f258029d0156cf23f940b694cf08d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_be4dd90ccb2f258029d0156cf23f940b694cf08d.hip new file mode 100644 index 000000000000..ad15084e4cc9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_be4dd90ccb2f258029d0156cf23f940b694cf08d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_be8ec1163a01b9cd9a802d8b44669e8770c20234.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_be8ec1163a01b9cd9a802d8b44669e8770c20234.hip new file mode 100644 index 000000000000..588646503161 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_be8ec1163a01b9cd9a802d8b44669e8770c20234.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_beae876d6da465687f162136231f15767cc7bb14.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_beae876d6da465687f162136231f15767cc7bb14.hip new file mode 100644 index 000000000000..2780b0d568bb --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_beae876d6da465687f162136231f15767cc7bb14.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); 
+#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_beb9afccc15de7dfcb2e7d898abc0d61201de73e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_beb9afccc15de7dfcb2e7d898abc0d61201de73e.hip new file mode 100644 index 000000000000..a1619cab297d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_beb9afccc15de7dfcb2e7d898abc0d61201de73e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bec30e7107c5dce3fe6aa87d83ed96da75478da0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bec30e7107c5dce3fe6aa87d83ed96da75478da0.hip new file mode 100644 index 000000000000..0957d6a9b1f2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bec30e7107c5dce3fe6aa87d83ed96da75478da0.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 32, + true, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<32, + ck_tile::bf16_t, + true, + true, + false, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bec9e4c0317e8d351f60258ed6611fbf365c4024.hip 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bec9e4c0317e8d351f60258ed6611fbf365c4024.hip new file mode 100644 index 000000000000..46512ea2b5ad --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bec9e4c0317e8d351f60258ed6611fbf365c4024.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void 
fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_becc2a4d7ac045365300bf8bd45fc6d3e1e1c8b1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_becc2a4d7ac045365300bf8bd45fc6d3e1e1c8b1.hip new file mode 100644 index 000000000000..389830690d07 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_becc2a4d7ac045365300bf8bd45fc6d3e1e1c8b1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bed5a8c5cf683f6dfaefad72c2e2f5c2f2b2732f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bed5a8c5cf683f6dfaefad72c2e2f5c2f2b2732f.hip new file mode 100644 index 000000000000..d33106b0be7a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bed5a8c5cf683f6dfaefad72c2e2f5c2f2b2732f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bef3bd014a918feddadc98eed92a7734f9bcd890.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bef3bd014a918feddadc98eed92a7734f9bcd890.hip new file mode 100644 index 000000000000..3e2a3ee96e88 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bef3bd014a918feddadc98eed92a7734f9bcd890.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bf9cdf86a7944cd690b0fcbbaec235863acd10bb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bf9cdf86a7944cd690b0fcbbaec235863acd10bb.hip new file mode 100644 index 000000000000..356a97e25a79 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_bf9cdf86a7944cd690b0fcbbaec235863acd10bb.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c0338fbc05f86270ded7df2bd3e2758a03961b62.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c0338fbc05f86270ded7df2bd3e2758a03961b62.hip new file mode 100644 index 000000000000..748c01991b59 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c0338fbc05f86270ded7df2bd3e2758a03961b62.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c0342686e4efd26413c6719782ed13603479c4e0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c0342686e4efd26413c6719782ed13603479c4e0.hip new file mode 100644 index 000000000000..54fc606a0bbc --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c0342686e4efd26413c6719782ed13603479c4e0.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = 
ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c063318cb851ccaa923be12d34c84d839bc64bb8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c063318cb851ccaa923be12d34c84d839bc64bb8.hip new file mode 100644 index 000000000000..f8b56fbd1b00 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c063318cb851ccaa923be12d34c84d839bc64bb8.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c08095341ca7e3a1debeb780c1878e351692bee2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c08095341ca7e3a1debeb780c1878e351692bee2.hip new file mode 100644 index 000000000000..3374fc97d095 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c08095341ca7e3a1debeb780c1878e351692bee2.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c0a3c4ac0a50bb9b7ad764929dbee98c856b1210.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c0a3c4ac0a50bb9b7ad764929dbee98c856b1210.hip new file mode 100644 index 000000000000..aa55230f9481 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c0a3c4ac0a50bb9b7ad764929dbee98c856b1210.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c0f76aff077c28f8afd7b22f284cf2894e08a043.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c0f76aff077c28f8afd7b22f284cf2894e08a043.hip new file mode 100644 index 000000000000..7050da910988 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c0f76aff077c28f8afd7b22f284cf2894e08a043.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 128, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128, + ck_tile::fp16_t, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c112c01d201c366bdd7acccf2e1b18b00f671153.hip 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c112c01d201c366bdd7acccf2e1b18b00f671153.hip new file mode 100644 index 000000000000..8d8c07d01fe6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c112c01d201c366bdd7acccf2e1b18b00f671153.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void 
fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c11d68fe766fc753c657362673704005b538660b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c11d68fe766fc753c657362673704005b538660b.hip new file mode 100644 index 000000000000..e54ab7e0dd0e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c11d68fe766fc753c657362673704005b538660b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + 
fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c137c03bf161b2ec6a9a046fa49d7bbf80ae47b8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c137c03bf161b2ec6a9a046fa49d7bbf80ae47b8.hip new file mode 100644 index 000000000000..c286c16ba83d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c137c03bf161b2ec6a9a046fa49d7bbf80ae47b8.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c197d1f050f42d82e6851fa286db6f81ba197f40.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c197d1f050f42d82e6851fa286db6f81ba197f40.hip new file mode 100644 index 000000000000..621c782e0e90 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c197d1f050f42d82e6851fa286db6f81ba197f40.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( 
+ s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c1b76bc7a17f573c0d52c07ae9ff4302662ae61f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c1b76bc7a17f573c0d52c07ae9ff4302662ae61f.hip new file mode 100644 index 000000000000..9d1e6d6551c1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c1b76bc7a17f573c0d52c07ae9ff4302662ae61f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = 
fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c1b94e19d762ddc33cc4e94c6675d93cbde21e3d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c1b94e19d762ddc33cc4e94c6675d93cbde21e3d.hip new file mode 100644 index 000000000000..a83be7a1699d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c1b94e19d762ddc33cc4e94c6675d93cbde21e3d.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c1f40c3421b9ad8cf43940530ec50bcf620058f2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c1f40c3421b9ad8cf43940530ec50bcf620058f2.hip new file mode 100644 index 000000000000..1da958d92095 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c1f40c3421b9ad8cf43940530ec50bcf620058f2.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c1f721a330b2d0fac13b22061616d7b10c0f91e9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c1f721a330b2d0fac13b22061616d7b10c0f91e9.hip new file mode 100644 index 000000000000..5917659c09ca --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c1f721a330b2d0fac13b22061616d7b10c0f91e9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c250ea59ab6e1ee39cce15cbd3f181047cdee31a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c250ea59ab6e1ee39cce15cbd3f181047cdee31a.hip new file mode 100644 index 000000000000..21c71b806abf --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c250ea59ab6e1ee39cce15cbd3f181047cdee31a.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + 
return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c2541b6b5cf27de3f45f60671d36602f07ce1783.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c2541b6b5cf27de3f45f60671d36602f07ce1783.hip new file mode 100644 index 000000000000..73c4cb917dff --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c2541b6b5cf27de3f45f60671d36602f07ce1783.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c27b3026f1dc3056dee3a3e64bf31c45683607c9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c27b3026f1dc3056dee3a3e64bf31c45683607c9.hip new file mode 100644 index 000000000000..d360841e5b2e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c27b3026f1dc3056dee3a3e64bf31c45683607c9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c28de8f96c8315877031a2d56261e95fee6aef44.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c28de8f96c8315877031a2d56261e95fee6aef44.hip new file mode 100644 index 000000000000..e7ad1d9e903b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c28de8f96c8315877031a2d56261e95fee6aef44.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c29110dd501853e87ebc122dd1971b0bb1bcd92f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c29110dd501853e87ebc122dd1971b0bb1bcd92f.hip new file mode 100644 index 000000000000..81aab7917a45 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c29110dd501853e87ebc122dd1971b0bb1bcd92f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c2940fd05efd52bdf8a3f9aa4b78bde9b5809b34.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c2940fd05efd52bdf8a3f9aa4b78bde9b5809b34.hip new file mode 100644 index 000000000000..e645a9cf03f8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c2940fd05efd52bdf8a3f9aa4b78bde9b5809b34.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( 
+ s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c2a2856bf9a81544a30d535a13554e3a8107c476.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c2a2856bf9a81544a30d535a13554e3a8107c476.hip new file mode 100644 index 000000000000..61bcb0c3a4d6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c2a2856bf9a81544a30d535a13554e3a8107c476.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = 
fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c2b719893a4d8a1e71857966d399f06c0a41749c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c2b719893a4d8a1e71857966d399f06c0a41749c.hip new file mode 100644 index 000000000000..f826edcf7a4a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c2b719893a4d8a1e71857966d399f06c0a41749c.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c2f04447e6a94c94a2315454e71d7d607a9fd0f8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c2f04447e6a94c94a2315454e71d7d607a9fd0f8.hip new file mode 100644 index 000000000000..3bf5eafe74bd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c2f04447e6a94c94a2315454e71d7d607a9fd0f8.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c2fcced07cc194a8050bc7b2f791453b3f5b2064.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c2fcced07cc194a8050bc7b2f791453b3f5b2064.hip new file mode 100644 index 000000000000..ab19c5ef64b2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c2fcced07cc194a8050bc7b2f791453b3f5b2064.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c323a4d1f24d59bddd20ed2f2fb6446627b0ae8b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c323a4d1f24d59bddd20ed2f2fb6446627b0ae8b.hip new file mode 100644 index 000000000000..dc7700e44f33 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c323a4d1f24d59bddd20ed2f2fb6446627b0ae8b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c355189ade9b1a8269230232db754a3881b53168.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c355189ade9b1a8269230232db754a3881b53168.hip new file mode 100644 index 000000000000..61d12216a26b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c355189ade9b1a8269230232db754a3881b53168.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c35ea54eb6cd0f3756c462c66d9be956279b46ad.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c35ea54eb6cd0f3756c462c66d9be956279b46ad.hip new file mode 100644 index 000000000000..791e9ed6ef1d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c35ea54eb6cd0f3756c462c66d9be956279b46ad.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS 
AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c363ee1b087f6b504a3dd3972b96e77db02b0582.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c363ee1b087f6b504a3dd3972b96e77db02b0582.hip new file mode 100644 index 000000000000..bba64793b7bd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c363ee1b087f6b504a3dd3972b96e77db02b0582.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << 
std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c3cfaf0d53869c373f6d0ec821b008dbb819141a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c3cfaf0d53869c373f6d0ec821b008dbb819141a.hip new file mode 100644 index 000000000000..53c8d770ad74 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c3cfaf0d53869c373f6d0ec821b008dbb819141a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using 
fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c3d0eaf9399c863d672e8c08d123739bab837d4b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c3d0eaf9399c863d672e8c08d123739bab837d4b.hip new file mode 100644 index 000000000000..cf4e6b089f6f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c3d0eaf9399c863d672e8c08d123739bab837d4b.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
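// Note: a hedged caller-side sketch, not part of this generated file. Once the runtime
// dispatcher has resolved a forward configuration to the trait_0 tag defined below, the
// launch presumably goes through the matching explicit specialization of the fmha_fwd_
// launcher, roughly:
//
//   ck_tile::stream_config sc{hip_stream};   // hip_stream: a hypothetical hipStream_t
//   fmha_fwd_args args{};                    // populated by the SDPA frontend in practice
//   float r = fmha_fwd_<trait_0>(sc, args);  // r: whatever ck_tile::launch_kernel reports
//
// The <trait_0> template argument is reconstructed from the surrounding definitions;
// sc and args are assumptions for illustration only.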
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c4015f0d0a7a5173810f6f17c00065e03fc61a89.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c4015f0d0a7a5173810f6f17c00065e03fc61a89.hip new file mode 100644 index 000000000000..d20fa308e525 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c4015f0d0a7a5173810f6f17c00065e03fc61a89.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
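// Note: a hedged caller-side sketch, not part of this generated file. Each backward
// unit specializes three launchers on the dq_dk_dv_trait_0 tag defined below:
// fmha_bwd_dq_dk_dv_ (returns the float that ck_tile::launch_kernel reports),
// fmha_bwd_dq_dk_dv_oneshot_ (invokes the kernel once without going through
// launch_kernel), and fmha_bwd_dq_dk_dv_get_name_ (returns the kernel's GetName()
// string), roughly:
//
//   ck_tile::stream_config sc{hip_stream};   // hip_stream: a hypothetical hipStream_t
//   fmha_bwd_args args{};                    // populated by the SDPA frontend in practice
//   float r = fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(sc, args);
//   fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(sc, args);
//   std::string name = fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>();
//
// The <dq_dk_dv_trait_0> template arguments are reconstructed from the surrounding
// definitions and are an assumption of this sketch.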
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c402e84359b2037a29efd1d6ce7213ba7605ab25.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c402e84359b2037a29efd1d6ce7213ba7605ab25.hip new file mode 100644 index 000000000000..7e73f68899bc --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c402e84359b2037a29efd1d6ce7213ba7605ab25.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c41b6eda4f250da059fe0c428428219ff5a250ef.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c41b6eda4f250da059fe0c428428219ff5a250ef.hip new file mode 100644 index 000000000000..9c7295dd911e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c41b6eda4f250da059fe0c428428219ff5a250ef.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 64, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<64, + ck_tile::fp16_t, + false, + false, + true, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c42ab428503e8f8bfa78c8cb8d9afad9f5185118.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c42ab428503e8f8bfa78c8cb8d9afad9f5185118.hip new file mode 100644 index 000000000000..836db9f6e86f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c42ab428503e8f8bfa78c8cb8d9afad9f5185118.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c4376ac8d82db1bc25fa273a80dfbf8b71ee5e2b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c4376ac8d82db1bc25fa273a80dfbf8b71ee5e2b.hip new file mode 100644 index 000000000000..ea220cf459a3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c4376ac8d82db1bc25fa273a80dfbf8b71ee5e2b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( 
+ s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c45a5e40f6a66bc5292a56e0097c69fe37cedfb3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c45a5e40f6a66bc5292a56e0097c69fe37cedfb3.hip new file mode 100644 index 000000000000..4f64a8af7d8a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c45a5e40f6a66bc5292a56e0097c69fe37cedfb3.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 64, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<64, + ck_tile::fp16_t, + false, + true, + false, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c487a1a9933239270f44b1e08e1cf5323521c089.hip 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c487a1a9933239270f44b1e08e1cf5323521c089.hip new file mode 100644 index 000000000000..13d2416038ce --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c487a1a9933239270f44b1e08e1cf5323521c089.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void 
fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c4997f79435cf64add10506acb97d0647cfbb3d4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c4997f79435cf64add10506acb97d0647cfbb3d4.hip new file mode 100644 index 000000000000..fa1724ec38cc --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c4997f79435cf64add10506acb97d0647cfbb3d4.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c4b34d3cb673447773f6da23e9cf52b98e99f718.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c4b34d3cb673447773f6da23e9cf52b98e99f718.hip new file mode 100644 index 000000000000..6af3d4ddbd43 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c4b34d3cb673447773f6da23e9cf52b98e99f718.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c4c3425fe683d35dc3335db77d183ad1620b7a92.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c4c3425fe683d35dc3335db77d183ad1620b7a92.hip new file mode 100644 index 000000000000..36817d57ff4c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c4c3425fe683d35dc3335db77d183ad1620b7a92.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c4c6c405cefe204824e8fad1b3dd34bba87e796a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c4c6c405cefe204824e8fad1b3dd34bba87e796a.hip new file mode 100644 index 000000000000..0ad57e26bb9e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c4c6c405cefe204824e8fad1b3dd34bba87e796a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c4de1bc135191f3c2aff740f4c6bb7e98da42f84.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c4de1bc135191f3c2aff740f4c6bb7e98da42f84.hip new file mode 100644 index 000000000000..020af5402f6f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c4de1bc135191f3c2aff740f4c6bb7e98da42f84.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c4dec99707511cebd9188d216ee0a148d729b470.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c4dec99707511cebd9188d216ee0a148d729b470.hip new file mode 100644 index 000000000000..953d23ed4f0b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c4dec99707511cebd9188d216ee0a148d729b470.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c538dc4f65d02776875627cbd20a9c794d70b043.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c538dc4f65d02776875627cbd20a9c794d70b043.hip new file mode 100644 index 000000000000..1ca6dbbccfd1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c538dc4f65d02776875627cbd20a9c794d70b043.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c53e295b68e807774ed31bb914e4bc59312a77d7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c53e295b68e807774ed31bb914e4bc59312a77d7.hip new file mode 100644 index 000000000000..74c245211839 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c53e295b68e807774ed31bb914e4bc59312a77d7.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + 
s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c56aa150611b0d4800470c1493dc907082a5c23f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c56aa150611b0d4800470c1493dc907082a5c23f.hip new file mode 100644 index 000000000000..88a7b6c0c0b3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c56aa150611b0d4800470c1493dc907082a5c23f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = 
fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c581974c8b6f43f60d0af29c350d850b55c03121.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c581974c8b6f43f60d0af29c350d850b55c03121.hip new file mode 100644 index 000000000000..1d1d428e5f09 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c581974c8b6f43f60d0af29c350d850b55c03121.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 64, + 256, + true, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<256, + ck_tile::bf16_t, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c59937be2b9a13d6520fdcc922e4e75c9fa085ab.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c59937be2b9a13d6520fdcc922e4e75c9fa085ab.hip new file mode 100644 index 000000000000..1f709adf33d0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c59937be2b9a13d6520fdcc922e4e75c9fa085ab.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c59a22c6efd8bb8815887325aa0b739e260cc754.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c59a22c6efd8bb8815887325aa0b739e260cc754.hip new file mode 100644 index 000000000000..ccbb66cd88d1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c59a22c6efd8bb8815887325aa0b739e260cc754.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c59ab718fa23f24f09a713ac28a339208a7a5802.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c59ab718fa23f24f09a713ac28a339208a7a5802.hip new file mode 100644 index 000000000000..c8c411a5257c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c59ab718fa23f24f09a713ac28a339208a7a5802.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); 
+#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c5b440ca9a5196ee1e72c878c87d96934e9273c8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c5b440ca9a5196ee1e72c878c87d96934e9273c8.hip new file mode 100644 index 000000000000..6e76b2218d49 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c5b440ca9a5196ee1e72c878c87d96934e9273c8.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) 
|| defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c5fcdea177734366d3bf283317a65cc3fffda611.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c5fcdea177734366d3bf283317a65cc3fffda611.hip new file mode 100644 index 000000000000..fd9d0e08b9ae --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c5fcdea177734366d3bf283317a65cc3fffda611.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c5fef330a975002ed15670e8e7b26a10376d3cb7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c5fef330a975002ed15670e8e7b26a10376d3cb7.hip new file mode 100644 index 000000000000..722aae67021a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c5fef330a975002ed15670e8e7b26a10376d3cb7.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c64f4cdce32189065362a502105c31bd2d9d99a4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c64f4cdce32189065362a502105c31bd2d9d99a4.hip new file mode 100644 index 000000000000..00a9f7a7abb6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c64f4cdce32189065362a502105c31bd2d9d99a4.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c6e2da8b791d31f4ba05ef5f833fd6dea9e35f1c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c6e2da8b791d31f4ba05ef5f833fd6dea9e35f1c.hip new file mode 100644 index 000000000000..2585c19fbbcd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c6e2da8b791d31f4ba05ef5f833fd6dea9e35f1c.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 32, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<32, + ck_tile::bf16_t, + false, + true, + false, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c7568e11e44ce70924d27e683190422cfae5c31d.hip 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c7568e11e44ce70924d27e683190422cfae5c31d.hip new file mode 100644 index 000000000000..d34160b1b16b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c7568e11e44ce70924d27e683190422cfae5c31d.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c7af2bbfac25de2853be344b9f636226c1c0112d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c7af2bbfac25de2853be344b9f636226c1c0112d.hip new file mode 100644 index 000000000000..1366275e4097 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c7af2bbfac25de2853be344b9f636226c1c0112d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c806d7803d06ef8aac1d5caac9f36aafd47653d5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c806d7803d06ef8aac1d5caac9f36aafd47653d5.hip new file mode 100644 index 000000000000..13d757d9843c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c806d7803d06ef8aac1d5caac9f36aafd47653d5.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c80dce1a17d073259250ec0c87ade69e639ffa8e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c80dce1a17d073259250ec0c87ade69e639ffa8e.hip new file mode 100644 index 000000000000..e50fb8502b00 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c80dce1a17d073259250ec0c87ade69e639ffa8e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; 
+ +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c8dbfaffc8a9b573f194f9c63f1175d9725f8950.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c8dbfaffc8a9b573f194f9c63f1175d9725f8950.hip new file mode 100644 index 000000000000..9ca6fc02b791 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c8dbfaffc8a9b573f194f9c63f1175d9725f8950.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c8f6461673882d636772ae4d26e78eabcb568f31.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c8f6461673882d636772ae4d26e78eabcb568f31.hip new file mode 100644 index 000000000000..e1be5aea0897 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c8f6461673882d636772ae4d26e78eabcb568f31.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c919b8ed877d4244d01a17ecb948b459e361ff24.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c919b8ed877d4244d01a17ecb948b459e361ff24.hip new file mode 100644 index 000000000000..8d6e7b9ce37b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c919b8ed877d4244d01a17ecb948b459e361ff24.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c921a4790f982d48bcaf950123c699647afb739b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c921a4790f982d48bcaf950123c699647afb739b.hip new file mode 100644 index 000000000000..f22eb9438c26 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c921a4790f982d48bcaf950123c699647afb739b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c9312d7159369d13f3148a6f0882dfad6921ceec.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c9312d7159369d13f3148a6f0882dfad6921ceec.hip new file mode 100644 index 000000000000..685fdcccebf6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c9312d7159369d13f3148a6f0882dfad6921ceec.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c9530e20038eb40c49bc8b045be0cf4e7e6b4eac.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c9530e20038eb40c49bc8b045be0cf4e7e6b4eac.hip new file mode 100644 index 000000000000..e7862a10c70c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c9530e20038eb40c49bc8b045be0cf4e7e6b4eac.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c977735a36c325706bd19a12df66ed0839b032b1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c977735a36c325706bd19a12df66ed0839b032b1.hip new file mode 100644 index 000000000000..fcede89ce04d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c977735a36c325706bd19a12df66ed0839b032b1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c9ad71883a19b522486706d3705700c012a6fc19.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c9ad71883a19b522486706d3705700c012a6fc19.hip new file mode 100644 index 000000000000..2cc6c54743da --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c9ad71883a19b522486706d3705700c012a6fc19.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c9ba0a3369d4e4eaea1c902a90e6501f232dd57c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c9ba0a3369d4e4eaea1c902a90e6501f232dd57c.hip new file mode 100644 index 000000000000..f6cee564bae8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c9ba0a3369d4e4eaea1c902a90e6501f232dd57c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c9f1e7e478a2208c4d32e2d7e6abebdc16bcc5fe.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c9f1e7e478a2208c4d32e2d7e6abebdc16bcc5fe.hip new file mode 100644 index 000000000000..8d7782c8a78b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c9f1e7e478a2208c4d32e2d7e6abebdc16bcc5fe.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
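+// This translation unit provides one fp16 backward (dQ/dK/dV) kernel
+// instantiation (KRKTRVR pipeline, ALIBI bias); the launch path is compiled
+// only for gfx90a/gfx942 and otherwise returns 0.0.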
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c9f28230817c9d9805c41dfcd4e834fe302e1df1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c9f28230817c9d9805c41dfcd4e834fe302e1df1.hip new file mode 100644 index 000000000000..abef9cc7af64 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c9f28230817c9d9805c41dfcd4e834fe302e1df1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c9fb8343e623e46f01893a2b61345d1ca5928671.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c9fb8343e623e46f01893a2b61345d1ca5928671.hip new file mode 100644 index 000000000000..10259551d95d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c9fb8343e623e46f01893a2b61345d1ca5928671.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; 
+ +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c9fe51f982abd60e567d4238d3266fb60e45814b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c9fe51f982abd60e567d4238d3266fb60e45814b.hip new file mode 100644 index 000000000000..68e98abb3da9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_c9fe51f982abd60e567d4238d3266fb60e45814b.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
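+// This translation unit instantiates the bf16 dQ-conversion step of the
+// backward pass: it converts the AccDataType dQ accumulator into the
+// QGradDataType output, with the same gfx90a/gfx942 guard around the launch.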
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 64, + 256, + true, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<256, + ck_tile::bf16_t, + true, + true, + false, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ca00cfdc5592b7440d72482a18781e9cf3afb05a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ca00cfdc5592b7440d72482a18781e9cf3afb05a.hip new file mode 100644 index 000000000000..04562d43e6a4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ca00cfdc5592b7440d72482a18781e9cf3afb05a.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
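+// This translation unit provides one bf16 forward attention instantiation
+// using the QRKSVS_ASYNC pipeline with no attention bias; only fmha_fwd_ is
+// specialized here, guarded to gfx90a/gfx942.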
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ca1992a2634cd6674076611be54197c715ad8271.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ca1992a2634cd6674076611be54197c715ad8271.hip new file mode 100644 index 000000000000..b1a8d1bdc078 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ca1992a2634cd6674076611be54197c715ad8271.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
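+// One fp16 backward (dQ/dK/dV) instantiation on the KRKTRVR_IGLP pipeline
+// with no attention bias; as in the other backward units, device code is
+// emitted only for gfx90a/gfx942.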
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ca3975efd767ddf7c12e308d948bdcaf0968493a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ca3975efd767ddf7c12e308d948bdcaf0968493a.hip new file mode 100644 index 000000000000..b797f8dbee0e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ca3975efd767ddf7c12e308d948bdcaf0968493a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ca3d98ff43fbb80ceb82fc22ab039bee898969b0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ca3d98ff43fbb80ceb82fc22ab039bee898969b0.hip new file mode 100644 index 000000000000..695453631ca6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ca3d98ff43fbb80ceb82fc22ab039bee898969b0.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, 
ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ca4c6ad28aff1976c6dd36974ec3b339aa3090e9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ca4c6ad28aff1976c6dd36974ec3b339aa3090e9.hip new file mode 100644 index 000000000000..300fb059215d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ca4c6ad28aff1976c6dd36974ec3b339aa3090e9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = 
k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ca5681d4e5871aacef74bdba9e368445875252d3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ca5681d4e5871aacef74bdba9e368445875252d3.hip new file mode 100644 index 000000000000..dd82712c7c81 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ca5681d4e5871aacef74bdba9e368445875252d3.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 64, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<64, + ck_tile::bf16_t, + false, + false, + true, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ca920c3239bb5796b1ab2fc75177eb3b820aa784.hip 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ca920c3239bb5796b1ab2fc75177eb3b820aa784.hip new file mode 100644 index 000000000000..8c2b34f48840 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ca920c3239bb5796b1ab2fc75177eb3b820aa784.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void 
fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cabb7b12cdd9b8b522af577e13232b2459dbd38d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cabb7b12cdd9b8b522af577e13232b2459dbd38d.hip new file mode 100644 index 000000000000..b3fda98f7f22 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cabb7b12cdd9b8b522af577e13232b2459dbd38d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + 
fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cae6c7efbfc831e2bcfc8c1efa1a486c02627cbf.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cae6c7efbfc831e2bcfc8c1efa1a486c02627cbf.hip new file mode 100644 index 000000000000..0fe3d7f3ef6c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cae6c7efbfc831e2bcfc8c1efa1a486c02627cbf.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
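+// One bf16 backward (dQ/dK/dV) instantiation on the KRKTRVR pipeline with no
+// attention bias; the three specializations below (launch, oneshot, name)
+// follow the same pattern as the other autogenerated backward units.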
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_caede7a18f3e3d5e24f6c70392413a2cda16ac15.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_caede7a18f3e3d5e24f6c70392413a2cda16ac15.hip new file mode 100644 index 000000000000..55a357335bb0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_caede7a18f3e3d5e24f6c70392413a2cda16ac15.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cb10303a0b79f2710eb7c66896d3c1f8b12c04dd.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cb10303a0b79f2710eb7c66896d3c1f8b12c04dd.hip new file mode 100644 index 000000000000..5c029304800b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cb10303a0b79f2710eb7c66896d3c1f8b12c04dd.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 64, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<64, + ck_tile::fp16_t, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cb1a0ce432c27f4cfa51731c3ef181bf60c8a727.hip 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cb1a0ce432c27f4cfa51731c3ef181bf60c8a727.hip new file mode 100644 index 000000000000..63387793002b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cb1a0ce432c27f4cfa51731c3ef181bf60c8a727.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template 
<> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cb1b91c16e0255fe7a0a85638b98d94634e143a9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cb1b91c16e0255fe7a0a85638b98d94634e143a9.hip new file mode 100644 index 000000000000..d3ed2d518168 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cb1b91c16e0255fe7a0a85638b98d94634e143a9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + 
fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cb1deea4f4fab0db31d46a91228601f0c272d6e6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cb1deea4f4fab0db31d46a91228601f0c272d6e6.hip new file mode 100644 index 000000000000..79177c2d6fc2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cb1deea4f4fab0db31d46a91228601f0c272d6e6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
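Every generated file in this directory follows the same shape: a handful of type aliases pinning down dtype, tile shape, pipeline and epilogue, a trait alias encoding that combination, and explicit specializations of the backward launcher templates for exactly that trait. The snippet below is a minimal, self-contained C++ sketch of that pattern only; the names stream_config, bwd_args, run_bwd_kernel, bwd_kernel_name and trait_fp16_hdim32_nobias are hypothetical stand-ins, not the real ck_tile API, and the body omits the actual kernel argument construction and HIP launch that the generated files perform behind the gfx90a / gfx942 guards.

#include <iostream>
#include <string>

// Hypothetical, simplified stand-ins for ck_tile::stream_config and
// fmha_bwd_args; the real structs live in the ck_tile headers.
struct stream_config { int log_level_ = 0; };
struct bwd_args {};

// A launcher template declared once in a shared header; each autogenerated
// translation unit contributes one explicit specialization keyed by a trait
// type that encodes the dtype / head-dim / bias / dropout combination.
template <typename Trait>
float run_bwd_kernel(const stream_config& s, bwd_args a);

template <typename Trait>
std::string bwd_kernel_name();

// One trait combination per generated file (hypothetical example).
struct trait_fp16_hdim32_nobias {};

template <>
std::string bwd_kernel_name<trait_fp16_hdim32_nobias>()
{
    return "fmha_bwd_fp16_hdim32_nobias";
}

template <>
float run_bwd_kernel<trait_fp16_hdim32_nobias>(const stream_config& s, bwd_args /*a*/)
{
    if (s.log_level_ > 0)
        std::cout << ", " << bwd_kernel_name<trait_fp16_hdim32_nobias>() << std::flush;
    // A real generated file would build kernel arguments, compute the launch
    // grid, and launch on the HIP stream, returning the measured time.
    return 0.0f;
}

int main()
{
    stream_config s{1};
    run_bwd_kernel<trait_fp16_hdim32_nobias>(s, bwd_args{});
    std::cout << '\n';
}

In this scheme the dispatcher only ever names a trait type; the linker then resolves the call to whichever generated translation unit provided the matching specialization, which is why each file above can be compiled independently.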
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cb20538073888bdb3174a8e9c32d7449072aa753.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cb20538073888bdb3174a8e9c32d7449072aa753.hip new file mode 100644 index 000000000000..d6d2f19f4afb --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cb20538073888bdb3174a8e9c32d7449072aa753.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cb3d5273945c5d40cc05c2660af2df1fb7a15f3c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cb3d5273945c5d40cc05c2660af2df1fb7a15f3c.hip new file mode 100644 index 000000000000..55519904a497 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cb3d5273945c5d40cc05c2660af2df1fb7a15f3c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cb4576e8ea5d59d7663f3760009a00a19e1b0667.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cb4576e8ea5d59d7663f3760009a00a19e1b0667.hip new file mode 100644 index 000000000000..4d0c9a2b57f5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cb4576e8ea5d59d7663f3760009a00a19e1b0667.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cbd571f4fe576fdb17d5f75a558cb6747087c7f2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cbd571f4fe576fdb17d5f75a558cb6747087c7f2.hip new file mode 100644 index 000000000000..9ab334b4303f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cbd571f4fe576fdb17d5f75a558cb6747087c7f2.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cbe5a98163e878c7697e554758ebd0597c2c1760.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cbe5a98163e878c7697e554758ebd0597c2c1760.hip new file mode 100644 index 000000000000..ebfb54a32255 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cbe5a98163e878c7697e554758ebd0597c2c1760.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cbf3e4d4d4837a0cb33b78c4f2767b1d93da0850.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cbf3e4d4d4837a0cb33b78c4f2767b1d93da0850.hip new file mode 100644 index 000000000000..84658a15c4fd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cbf3e4d4d4837a0cb33b78c4f2767b1d93da0850.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cc127a63d56099e08125b16939dac82f0173122b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cc127a63d56099e08125b16939dac82f0173122b.hip new file mode 100644 index 000000000000..d40da5d5b20c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cc127a63d56099e08125b16939dac82f0173122b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cc4ac5a18f57f2ebb65f7e356e858ab0d59b2133.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cc4ac5a18f57f2ebb65f7e356e858ab0d59b2133.hip new file mode 100644 index 000000000000..6c4df694544b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cc4ac5a18f57f2ebb65f7e356e858ab0d59b2133.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 64, + true, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<64, + ck_tile::bf16_t, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cc54b107e1b557ea36b5cbaf7fe3dfce05415c86.hip 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cc54b107e1b557ea36b5cbaf7fe3dfce05415c86.hip new file mode 100644 index 000000000000..1639ece1b20f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cc54b107e1b557ea36b5cbaf7fe3dfce05415c86.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> 
+void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ccac6c0e61b65c9422c7f30fbd979031698370a9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ccac6c0e61b65c9422c7f30fbd979031698370a9.hip new file mode 100644 index 000000000000..b7a2e3905d5b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ccac6c0e61b65c9422c7f30fbd979031698370a9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + 
fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ccd0b777df1328bf24e070ed4cdf8615bb2199fe.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ccd0b777df1328bf24e070ed4cdf8615bb2199fe.hip new file mode 100644 index 000000000000..39bdcc20a8b1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ccd0b777df1328bf24e070ed4cdf8615bb2199fe.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cd0453a5c3828c1358360f31f5d3b7258e17fdb9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cd0453a5c3828c1358360f31f5d3b7258e17fdb9.hip new file mode 100644 index 000000000000..51260dc0ce9e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cd0453a5c3828c1358360f31f5d3b7258e17fdb9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) 
+ return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cd4efcdd12184211c74e7b3f2f30fecf1041ca32.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cd4efcdd12184211c74e7b3f2f30fecf1041ca32.hip new file mode 100644 index 000000000000..78137a331f01 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cd4efcdd12184211c74e7b3f2f30fecf1041ca32.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cd757a8bbeabd16a44d149ab188430f6d79ddcaf.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cd757a8bbeabd16a44d149ab188430f6d79ddcaf.hip new file mode 100644 index 000000000000..eb744ff32796 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cd757a8bbeabd16a44d149ab188430f6d79ddcaf.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cde0582e1aef74f9209de638b553ec0671476258.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cde0582e1aef74f9209de638b553ec0671476258.hip new file mode 100644 index 000000000000..2307064b178a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cde0582e1aef74f9209de638b553ec0671476258.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ce4714e4f33340859c106a3129993e22652262e2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ce4714e4f33340859c106a3129993e22652262e2.hip new file mode 100644 index 000000000000..32a84284387f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ce4714e4f33340859c106a3129993e22652262e2.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else 
+ return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ce5064e27ba427cb951f7e1b01328b0beb6b2b7c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ce5064e27ba427cb951f7e1b01328b0beb6b2b7c.hip new file mode 100644 index 000000000000..9ce46029f56b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ce5064e27ba427cb951f7e1b01328b0beb6b2b7c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ce5ad502dd40353312d561e9f40aa478c16ef5b1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ce5ad502dd40353312d561e9f40aa478c16ef5b1.hip new file mode 100644 index 000000000000..77d65db91d85 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ce5ad502dd40353312d561e9f40aa478c16ef5b1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ce5b5932f6df9a194ceb0d69220fba9596528eec.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ce5b5932f6df9a194ceb0d69220fba9596528eec.hip new file mode 100644 index 000000000000..0411dfdd8e73 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ce5b5932f6df9a194ceb0d69220fba9596528eec.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ce5c161b725becf059fb4439c668edd454ac77d1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ce5c161b725becf059fb4439c668edd454ac77d1.hip new file mode 100644 index 000000000000..051745befa82 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ce5c161b725becf059fb4439c668edd454ac77d1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ce909cb5f96a4884caa0d2eb8c5e6bc7fa352797.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ce909cb5f96a4884caa0d2eb8c5e6bc7fa352797.hip new file mode 100644 index 000000000000..951758448779 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ce909cb5f96a4884caa0d2eb8c5e6bc7fa352797.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ceb9544e2a0caae2c9e3dd8bbd2c509e8dca1379.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ceb9544e2a0caae2c9e3dd8bbd2c509e8dca1379.hip new file mode 100644 index 000000000000..e9b81569dce8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ceb9544e2a0caae2c9e3dd8bbd2c509e8dca1379.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cee81ab2e2678816c7b516d2d4c50e8cb5874c68.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cee81ab2e2678816c7b516d2d4c50e8cb5874c68.hip new file mode 100644 index 000000000000..ac3fce28609f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cee81ab2e2678816c7b516d2d4c50e8cb5874c68.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cf5c6c0bfaf98f6e655fc443246b81fcc730fe97.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cf5c6c0bfaf98f6e655fc443246b81fcc730fe97.hip new file mode 100644 index 000000000000..940795ee7ce0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cf5c6c0bfaf98f6e655fc443246b81fcc730fe97.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cf73e1fc0015094861ca0c1c81bacdbe0c5b8f37.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cf73e1fc0015094861ca0c1c81bacdbe0c5b8f37.hip new file mode 100644 index 000000000000..237f75bd36ca --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cf73e1fc0015094861ca0c1c81bacdbe0c5b8f37.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cfda56a4eb08b803332f25bda6209932d9624acc.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cfda56a4eb08b803332f25bda6209932d9624acc.hip new file mode 100644 index 000000000000..c9500a1d03de --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cfda56a4eb08b803332f25bda6209932d9624acc.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cfec97bdfb6fa95e057eaf5a8138853e1c0884f2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cfec97bdfb6fa95e057eaf5a8138853e1c0884f2.hip new file mode 100644 index 000000000000..3a213b6a76dd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_cfec97bdfb6fa95e057eaf5a8138853e1c0884f2.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d00f65bc99ca08eba66564d34f72f2769bff9491.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d00f65bc99ca08eba66564d34f72f2769bff9491.hip new file mode 100644 index 000000000000..526625594bc5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d00f65bc99ca08eba66564d34f72f2769bff9491.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d036096f49a89730f8af7e75457c88cb8ae64165.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d036096f49a89730f8af7e75457c88cb8ae64165.hip new file mode 100644 index 000000000000..155d14f6df70 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d036096f49a89730f8af7e75457c88cb8ae64165.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d049a1b8f4c1c6d37973ce38593efda1de8ce0cd.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d049a1b8f4c1c6d37973ce38593efda1de8ce0cd.hip new file mode 100644 index 000000000000..3f90311218b6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d049a1b8f4c1c6d37973ce38593efda1de8ce0cd.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d04dc4ed02eb42c3fe303342801ed3073a0dcb8e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d04dc4ed02eb42c3fe303342801ed3073a0dcb8e.hip new file mode 100644 index 000000000000..8c9be1eefe11 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d04dc4ed02eb42c3fe303342801ed3073a0dcb8e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d06ba4c996570ddab77b6ff1e2a0101b638543eb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d06ba4c996570ddab77b6ff1e2a0101b638543eb.hip new file mode 100644 index 000000000000..e38d6adf2ba7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d06ba4c996570ddab77b6ff1e2a0101b638543eb.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d0863830fc5d43dc6d6400280e892bb7de2892d4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d0863830fc5d43dc6d6400280e892bb7de2892d4.hip new file mode 100644 index 000000000000..21d521abdcd1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d0863830fc5d43dc6d6400280e892bb7de2892d4.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d090b771a4f9750132f549c82a88b4ab00dce5c7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d090b771a4f9750132f549c82a88b4ab00dce5c7.hip new file mode 100644 index 000000000000..abb7de42e88c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d090b771a4f9750132f549c82a88b4ab00dce5c7.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d0b09e8513646fbb2a007544a63ec9e2b04dc4c2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d0b09e8513646fbb2a007544a63ec9e2b04dc4c2.hip new file mode 100644 index 000000000000..80b87189a9bc --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d0b09e8513646fbb2a007544a63ec9e2b04dc4c2.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d0daa59f5dce6fc3965193ae37d8c82a3d1834e6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d0daa59f5dce6fc3965193ae37d8c82a3d1834e6.hip new file mode 100644 index 000000000000..e973fe505b70 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d0daa59f5dce6fc3965193ae37d8c82a3d1834e6.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d0dd0165ee91c095a19ceddf08789e3576912590.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d0dd0165ee91c095a19ceddf08789e3576912590.hip new file mode 100644 index 000000000000..69856a46bb5a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d0dd0165ee91c095a19ceddf08789e3576912590.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d0de618ff3ea9f67b90f2227fb7fcc74ea34183d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d0de618ff3ea9f67b90f2227fb7fcc74ea34183d.hip new file mode 100644 index 000000000000..72ec5b33008b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d0de618ff3ea9f67b90f2227fb7fcc74ea34183d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d0f63cafbeb445408c884727b473667fb479675e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d0f63cafbeb445408c884727b473667fb479675e.hip new file mode 100644 index 000000000000..190762434427 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d0f63cafbeb445408c884727b473667fb479675e.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, 
kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d137b7b6e04e1caf43a62bd6788a75361cfa98f6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d137b7b6e04e1caf43a62bd6788a75361cfa98f6.hip new file mode 100644 index 000000000000..de13791a91fa --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d137b7b6e04e1caf43a62bd6788a75361cfa98f6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d1840494c4fa78ff399c0399b3ad7ca3d22d4587.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d1840494c4fa78ff399c0399b3ad7ca3d22d4587.hip new file mode 100644 index 000000000000..0b9b5772cd6b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d1840494c4fa78ff399c0399b3ad7ca3d22d4587.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d18727988e47264b42b4153dc82fc1a750f08db0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d18727988e47264b42b4153dc82fc1a750f08db0.hip new file mode 100644 index 000000000000..4d02471bf8f2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d18727988e47264b42b4153dc82fc1a750f08db0.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 64, + 256, + true, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<256, + ck_tile::fp16_t, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d1c0dfd19a08d61586758091370acbdc6f267017.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d1c0dfd19a08d61586758091370acbdc6f267017.hip new file mode 100644 index 000000000000..f3dc0e43f13d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d1c0dfd19a08d61586758091370acbdc6f267017.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d1c25cfc437d8bd803860e39a45b2f3b9fa48393.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d1c25cfc437d8bd803860e39a45b2f3b9fa48393.hip new file mode 100644 index 000000000000..5e557689795f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d1c25cfc437d8bd803860e39a45b2f3b9fa48393.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + 
s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d1d3eacc320104100bce46235fe656e5a8223c66.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d1d3eacc320104100bce46235fe656e5a8223c66.hip new file mode 100644 index 000000000000..779433cfd7ea --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d1d3eacc320104100bce46235fe656e5a8223c66.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = 
fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d20d45aa85c0daa299da98c277cee826fe67bd27.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d20d45aa85c0daa299da98c277cee826fe67bd27.hip new file mode 100644 index 000000000000..8f05e1243fd5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d20d45aa85c0daa299da98c277cee826fe67bd27.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d257148f457557ea80ca56690e525db3a4b0ff55.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d257148f457557ea80ca56690e525db3a4b0ff55.hip new file mode 100644 index 000000000000..5dba5da8184d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d257148f457557ea80ca56690e525db3a4b0ff55.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 64, + 256, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<256, + ck_tile::fp16_t, + false, + true, + false, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d25ce4b3e9cc392ceafebc7fe3bcbe05aaad4bbc.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d25ce4b3e9cc392ceafebc7fe3bcbe05aaad4bbc.hip new file mode 100644 index 000000000000..28fa3b5d1cfe --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d25ce4b3e9cc392ceafebc7fe3bcbe05aaad4bbc.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d2d08c5470a385d0160b2c1441fd1c30fff1c17c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d2d08c5470a385d0160b2c1441fd1c30fff1c17c.hip new file mode 100644 index 000000000000..a7555228dfa4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d2d08c5470a385d0160b2c1441fd1c30fff1c17c.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d2daccc4b3a0f90bff39cb4597f8b7e484613d9e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d2daccc4b3a0f90bff39cb4597f8b7e484613d9e.hip new file mode 100644 index 000000000000..45a37392130f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d2daccc4b3a0f90bff39cb4597f8b7e484613d9e.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d2dfdb42c1b380e860aa5609302f29698dd27923.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d2dfdb42c1b380e860aa5609302f29698dd27923.hip new file mode 100644 index 000000000000..ac951973b6c9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d2dfdb42c1b380e860aa5609302f29698dd27923.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d2f4b869ff23874b6bde0aab68c419108b7e69f4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d2f4b869ff23874b6bde0aab68c419108b7e69f4.hip new file mode 100644 index 000000000000..f8a749915303 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d2f4b869ff23874b6bde0aab68c419108b7e69f4.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d32c64ef01aa228277d031a74df51363f98aa2b0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d32c64ef01aa228277d031a74df51363f98aa2b0.hip new file mode 100644 index 000000000000..13d7502c3a17 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d32c64ef01aa228277d031a74df51363f98aa2b0.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 32, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<32, + ck_tile::bf16_t, + false, + false, + true, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d34d6cdcd81a456125ab5e0875466c6334d8e5c8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d34d6cdcd81a456125ab5e0875466c6334d8e5c8.hip new file mode 100644 index 000000000000..9654803a8750 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d34d6cdcd81a456125ab5e0875466c6334d8e5c8.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d34fcb56caa8f80404789fba0ffac447483a4d84.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d34fcb56caa8f80404789fba0ffac447483a4d84.hip new file mode 100644 index 000000000000..9d12efc0749f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d34fcb56caa8f80404789fba0ffac447483a4d84.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d3784fb4c0685d7b651f4113f3c71e050881f3a5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d3784fb4c0685d7b651f4113f3c71e050881f3a5.hip new file mode 100644 index 000000000000..de697b2f2c11 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d3784fb4c0685d7b651f4113f3c71e050881f3a5.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d3a23ded424200d0c6f06b1dbd0a7b7b0e7b5d9b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d3a23ded424200d0c6f06b1dbd0a7b7b0e7b5d9b.hip new file mode 100644 index 000000000000..2cc4d9d90a25 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d3a23ded424200d0c6f06b1dbd0a7b7b0e7b5d9b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d3a2edf232786d458e2125f8dfeda8847f842afa.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d3a2edf232786d458e2125f8dfeda8847f842afa.hip new file mode 100644 index 000000000000..27aca3a502c7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d3a2edf232786d458e2125f8dfeda8847f842afa.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d3af8763f289dace1054bdcb4dfeda28b0aefcae.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d3af8763f289dace1054bdcb4dfeda28b0aefcae.hip new file mode 100644 index 000000000000..b6b582238192 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d3af8763f289dace1054bdcb4dfeda28b0aefcae.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 64, + true, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<64, + ck_tile::fp16_t, + true, + true, + false, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d3fce1e11aee2273620e75efe4aa0390fcde9ba5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d3fce1e11aee2273620e75efe4aa0390fcde9ba5.hip new file mode 100644 index 000000000000..c2d920704a67 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d3fce1e11aee2273620e75efe4aa0390fcde9ba5.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d40569ae9dbd693c0ab3d6ba69704d31e451011b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d40569ae9dbd693c0ab3d6ba69704d31e451011b.hip new file mode 100644 index 000000000000..e3e7d8ca7231 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d40569ae9dbd693c0ab3d6ba69704d31e451011b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
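+// Note (hedged annotation, not emitted by generate.py): as far as the
+// surrounding code suggests, the dq_dk_dv_trait_0 alias below is the
+// compile-time key for this variant (head dim, dtype, pipeline, mask /
+// bias / dropout flags); the fmha_bwd launch and name entry points that
+// follow are specialized on it so the dispatcher can select this kernel.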
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d41b6a64dd181f2efa65aaed03a3d229b3566c1d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d41b6a64dd181f2efa65aaed03a3d229b3566c1d.hip new file mode 100644 index 000000000000..a6e79f9bf0c2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d41b6a64dd181f2efa65aaed03a3d229b3566c1d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d41cd6b60a97e7071518cbd1a63abb8b910df024.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d41cd6b60a97e7071518cbd1a63abb8b910df024.hip new file mode 100644 index 000000000000..b1609b190fa1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d41cd6b60a97e7071518cbd1a63abb8b910df024.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d43715cce8935439f90172d141050d78c7e76fb7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d43715cce8935439f90172d141050d78c7e76fb7.hip new file mode 100644 index 000000000000..8173e29193cd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d43715cce8935439f90172d141050d78c7e76fb7.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, 
grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d4605b2ad3e3753c5f255678abc1690b949c5abc.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d4605b2ad3e3753c5f255678abc1690b949c5abc.hip new file mode 100644 index 000000000000..6cadf5fbdd6e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d4605b2ad3e3753c5f255678abc1690b949c5abc.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if 
(defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d4645b713821371161a9925dec8a3d6c157ba1aa.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d4645b713821371161a9925dec8a3d6c157ba1aa.hip new file mode 100644 index 000000000000..dd32abcd1e4d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d4645b713821371161a9925dec8a3d6c157ba1aa.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, 
ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d4aff499ad527be5fe33b8e92547df57af26d40d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d4aff499ad527be5fe33b8e92547df57af26d40d.hip new file mode 100644 index 000000000000..3c007c9e3723 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d4aff499ad527be5fe33b8e92547df57af26d40d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t 
kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d4b99af9a573df50a27fccbec3fa8e350f1854eb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d4b99af9a573df50a27fccbec3fa8e350f1854eb.hip new file mode 100644 index 000000000000..0244535a25a3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d4b99af9a573df50a27fccbec3fa8e350f1854eb.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, 
+ true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d4c9f975891087e6eed6393629b41155deafc509.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d4c9f975891087e6eed6393629b41155deafc509.hip new file mode 100644 index 000000000000..b755f2652f62 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d4c9f975891087e6eed6393629b41155deafc509.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
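Every generated .hip file in this patch follows the same skeleton: compile-time tile/warp shape aliases, one ck_tile kernel instantiation, a *_traits_ alias naming the configuration, and explicit specializations of the launch/name entry points whose device code is only compiled for gfx90a and gfx942 (returning 0.0, or doing nothing, on other architectures). The stand-alone sketch below illustrates that trait-keyed specialization pattern only; the names Trait and launch_for and the <128, true> parameters are illustrative placeholders, not the actual ck_tile API.

// Minimal sketch of the trait-keyed dispatch used by the generated files.
// "Trait" and "launch_for" are hypothetical stand-ins, not ck_tile names.
#include <iostream>

template <int HeadDim, bool kIsCausal>
struct Trait {};                 // compile-time kernel configuration

template <typename T>
float launch_for();              // declared for every configuration

// Each generated translation unit supplies exactly one specialization,
// with real device code only on the supported GPU architectures.
template <>
float launch_for<Trait<128, true>>()
{
#if (defined(__gfx90a__) || defined(__gfx942__))
    // real files: ck_tile::launch_kernel(s, ck_tile::make_kernel(...))
    return 1.0f;                 // stand-in for the measured kernel time
#else
    return 0.0f;                 // unsupported arch: no-op, as in the patch
#endif
}

int main()
{
    std::cout << launch_for<Trait<128, true>>() << '\n';
    return 0;
}

One specialization per translation unit presumably lets the many tile/dtype/mask combinations compile independently and be resolved at link time by the runtime dispatcher; that is an inference from the file layout, not something stated in the patch.
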
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d50ac8e8a03f8e7ec2c6e993dd39f09f465dab57.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d50ac8e8a03f8e7ec2c6e993dd39f09f465dab57.hip new file mode 100644 index 000000000000..bfe25191b322 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d50ac8e8a03f8e7ec2c6e993dd39f09f465dab57.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d54ac01458df3f240e0656d82330f9de23ba9651.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d54ac01458df3f240e0656d82330f9de23ba9651.hip new file mode 100644 index 000000000000..1580ca2a5d27 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d54ac01458df3f240e0656d82330f9de23ba9651.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; 
+ +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d54b3731883a5f8393d60d27487f8d017aedd3f9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d54b3731883a5f8393d60d27487f8d017aedd3f9.hip new file mode 100644 index 000000000000..5e6c8d143420 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d54b3731883a5f8393d60d27487f8d017aedd3f9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d5e82799f4452e148c3e02acd6526cf30757eb52.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d5e82799f4452e148c3e02acd6526cf30757eb52.hip new file mode 100644 index 000000000000..6e06047e93af --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d5e82799f4452e148c3e02acd6526cf30757eb52.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d5edfe3e3dc3008b928c8e6dbd50784b905f189e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d5edfe3e3dc3008b928c8e6dbd50784b905f189e.hip new file mode 100644 index 000000000000..44a58feba4d9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d5edfe3e3dc3008b928c8e6dbd50784b905f189e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d600779c17b7b21c18e1308e6d765fe02a7945d3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d600779c17b7b21c18e1308e6d765fe02a7945d3.hip new file mode 100644 index 000000000000..06104fd7b834 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d600779c17b7b21c18e1308e6d765fe02a7945d3.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d6149eea92f2c40c11de3b778102fcf9b6a006b8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d6149eea92f2c40c11de3b778102fcf9b6a006b8.hip new file mode 100644 index 000000000000..67d05f1be2d8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d6149eea92f2c40c11de3b778102fcf9b6a006b8.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE 
IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, 
blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d623b36cc3f56d1001b2d3abadd8a5628fefd014.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d623b36cc3f56d1001b2d3abadd8a5628fefd014.hip new file mode 100644 index 000000000000..4aac45ef8387 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d623b36cc3f56d1001b2d3abadd8a5628fefd014.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d63c8c746055851217a514321cd735eaf6937263.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d63c8c746055851217a514321cd735eaf6937263.hip new file mode 100644 index 000000000000..364b252ab1a8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d63c8c746055851217a514321cd735eaf6937263.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = 
k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d64b8b52f4a98801e185e2f132b2f80c29dd0c37.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d64b8b52f4a98801e185e2f132b2f80c29dd0c37.hip new file mode 100644 index 000000000000..7f9835a1a67e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d64b8b52f4a98801e185e2f132b2f80c29dd0c37.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d66b79c4ebdcfd239cecec58203606bc123bd6bb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d66b79c4ebdcfd239cecec58203606bc123bd6bb.hip new file mode 100644 index 000000000000..309f85c7f833 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d66b79c4ebdcfd239cecec58203606bc123bd6bb.hip @@ -0,0 +1,84 @@ 
+// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d66c30148a6fa816937f2f095802264d3dfa0273.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d66c30148a6fa816937f2f095802264d3dfa0273.hip new file mode 100644 index 000000000000..681e27343833 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d66c30148a6fa816937f2f095802264d3dfa0273.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d703eea8075cacec4d41fee7dc4734f593ee79e8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d703eea8075cacec4d41fee7dc4734f593ee79e8.hip new file mode 100644 index 000000000000..911884f18d6a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d703eea8075cacec4d41fee7dc4734f593ee79e8.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d712f23ef88ae5d7b161d36f42d22a5ba53b6354.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d712f23ef88ae5d7b161d36f42d22a5ba53b6354.hip new file mode 100644 index 000000000000..735e3461e03e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d712f23ef88ae5d7b161d36f42d22a5ba53b6354.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d713fe25dc90b3511fc259cebf463376dcb55d84.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d713fe25dc90b3511fc259cebf463376dcb55d84.hip new file mode 100644 index 000000000000..f31a6366ab5f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d713fe25dc90b3511fc259cebf463376dcb55d84.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
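
// A minimal, standalone sketch of the pattern the generated units above share:
// each translation unit defines one trait alias (e.g. dq_dk_dv_trait_0) and then
// provides explicit specializations fmha_bwd_dq_dk_dv_<trait>,
// fmha_bwd_dq_dk_dv_oneshot_<trait> and fmha_bwd_dq_dk_dv_get_name_<trait>, with
// the actual launch guarded for gfx90a/gfx942. The sketch below uses hypothetical
// names (Trait32BF16, Trait128FP16, run_, run_get_name_) and plain C++ only to
// illustrate the "one explicit specialization per trait" layout; it is not the
// real fmha_bwd API, whose declarations come from the include elided in this
// excerpt.
#include <iostream>
#include <string>

// Primary templates: in the real code these are declared once in a shared
// header and each generated .hip file supplies exactly one specialization.
template <typename Trait>
float run_(float scale);

template <typename Trait>
std::string run_get_name_();

// Hypothetical trait tags, standing in for the dq_dk_dv_trait_0 aliases of two
// different generated translation units.
struct Trait32BF16 {};
struct Trait128FP16 {};

// "Generated unit" #1: specialization keyed on Trait32BF16.
template <>
float run_<Trait32BF16>(float scale) { return 32.0f * scale; }
template <>
std::string run_get_name_<Trait32BF16>() { return "fmha_bwd_hd32_bf16"; }

// "Generated unit" #2: specialization keyed on Trait128FP16.
template <>
float run_<Trait128FP16>(float scale) { return 128.0f * scale; }
template <>
std::string run_get_name_<Trait128FP16>() { return "fmha_bwd_hd128_fp16"; }

int main() {
    // The dispatcher picks the trait that matches the runtime problem
    // (head dim, dtype, mask, bias, dropout, padding) and calls its specialization.
    std::cout << run_get_name_<Trait32BF16>()  << ": " << run_<Trait32BF16>(2.0f)  << '\n';
    std::cout << run_get_name_<Trait128FP16>() << ": " << run_<Trait128FP16>(2.0f) << '\n';
    return 0;
}
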
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d7145383e39dec0e346b5094401acf85ef3c2075.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d7145383e39dec0e346b5094401acf85ef3c2075.hip new file mode 100644 index 000000000000..da226878f4a9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d7145383e39dec0e346b5094401acf85ef3c2075.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d723b191785c97d284675f700a7baeb52a2eb791.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d723b191785c97d284675f700a7baeb52a2eb791.hip new file mode 100644 index 000000000000..2b519281fc38 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d723b191785c97d284675f700a7baeb52a2eb791.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; 
+ +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d7290cc4c3036c9205e689cbcc60e7d16b97a7d6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d7290cc4c3036c9205e689cbcc60e7d16b97a7d6.hip new file mode 100644 index 000000000000..6e496a908dc2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d7290cc4c3036c9205e689cbcc60e7d16b97a7d6.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d733f4c03e338ea7c6d8f759c1132499bdcea059.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d733f4c03e338ea7c6d8f759c1132499bdcea059.hip new file mode 100644 index 000000000000..5bf0ce3079aa --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d733f4c03e338ea7c6d8f759c1132499bdcea059.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d773df9ccfc1ace90fe3afb5c00976deabedf6f8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d773df9ccfc1ace90fe3afb5c00976deabedf6f8.hip new file mode 100644 index 000000000000..6f5ddf20cbf7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d773df9ccfc1ace90fe3afb5c00976deabedf6f8.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d7adde8780b39f1364c572a19c3bfb19417678e3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d7adde8780b39f1364c572a19c3bfb19417678e3.hip new file mode 100644 index 000000000000..f0e55dd85db1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d7adde8780b39f1364c572a19c3bfb19417678e3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d7bda8157fb27d544e049fd7d2ec735725f1bf44.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d7bda8157fb27d544e049fd7d2ec735725f1bf44.hip new file mode 100644 index 000000000000..01bac69a55bb --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d7bda8157fb27d544e049fd7d2ec735725f1bf44.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d7fae2c18645d36a181a0bdd2d8ca7a4ac0f6d1d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d7fae2c18645d36a181a0bdd2d8ca7a4ac0f6d1d.hip new file mode 100644 index 000000000000..8fcd9523af0f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d7fae2c18645d36a181a0bdd2d8ca7a4ac0f6d1d.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
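+// Structure of these forward-kernel instantiation files (a descriptive summary
+// of the pattern visible in the surrounding autogenerated sources, not part of
+// the generator's own output): each file fixes one tile configuration.
+// fmha_block_tile_0 and fmha_warp_tile_0 are composed by TileFmhaShape into
+// fmha_shape_0; fmha_trait_0, fmha_mask_0 and the FmhaFwdTypeConfig data types
+// populate BlockFmhaPipelineProblem, which the QRKSVS / QRKSVS_ASYNC pipeline
+// consumes. FmhaFwdKernel then pairs that pipeline with a Default2DEpilogue,
+// and trait_0 appears to serve as the dispatch key for the fmha_fwd_ entry
+// point below; on architectures other than gfx90a/gfx942 that launcher
+// compiles to a stub returning 0.0 (see the #if guard at the bottom).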
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d82773721479613ad72e334510a248f1436b38d6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d82773721479613ad72e334510a248f1436b38d6.hip new file mode 100644 index 000000000000..de30f89448d2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d82773721479613ad72e334510a248f1436b38d6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
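+// Structure of the backward (dq/dk/dv) instantiation files (a descriptive
+// summary of the pattern repeated in the surrounding autogenerated sources,
+// not part of the generator's own output): fmha_block_tile_0 plus the
+// block-warp and warp-tile sequences define TileFmhaBwdShape for the GSdP,
+// GdKV and GdQ GEMMs noted in the TODO below; BlockFmhaBwdPipelineProblem
+// wires the FmhaBwdTypeConfig data types, the mask and the dropout policy into
+// a KRKTRVR / KRKTRVR_IGLP pipeline; FmhaBwdDQDKDVKernel pairs that pipeline
+// with the separate dK and dV epilogues; and dq_dk_dv_trait_0 appears to be
+// the dispatch key for the three entry points defined below (timed launch,
+// one-shot launch, GetName), all compiled only for gfx90a/gfx942.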
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d867098db97b3f26e71a151c63b74260bfab21f8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d867098db97b3f26e71a151c63b74260bfab21f8.hip new file mode 100644 index 000000000000..6e8215d41b50 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d867098db97b3f26e71a151c63b74260bfab21f8.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d86e4dcbe9c4cac8f7c8c5d97ce384ae0cbdbfbc.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d86e4dcbe9c4cac8f7c8c5d97ce384ae0cbdbfbc.hip new file mode 100644 index 000000000000..854b06d1b4b8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d86e4dcbe9c4cac8f7c8c5d97ce384ae0cbdbfbc.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d8901a63986cc28ef24cab012b32114851a8c1ec.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d8901a63986cc28ef24cab012b32114851a8c1ec.hip new file mode 100644 index 000000000000..468d51dcd22a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d8901a63986cc28ef24cab012b32114851a8c1ec.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d9061c204d8a85c974676f4438994a0be9d69a60.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d9061c204d8a85c974676f4438994a0be9d69a60.hip new file mode 100644 index 000000000000..50210fe32ef8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d9061c204d8a85c974676f4438994a0be9d69a60.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d924ee32b178b6bffa7a71603d6e2818f66177a5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d924ee32b178b6bffa7a71603d6e2818f66177a5.hip new file mode 100644 index 000000000000..77cb2deae9a9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d924ee32b178b6bffa7a71603d6e2818f66177a5.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d937609afa8e21a761dad6b01ff3f26346e450fc.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d937609afa8e21a761dad6b01ff3f26346e450fc.hip new file mode 100644 index 000000000000..9c64d867fe5f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d937609afa8e21a761dad6b01ff3f26346e450fc.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d95835bc6f000d3a3379bbc38d90e83dcaf867ee.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d95835bc6f000d3a3379bbc38d90e83dcaf867ee.hip new file mode 100644 index 000000000000..f2caaa8f0d02 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d95835bc6f000d3a3379bbc38d90e83dcaf867ee.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d992eab7de49033f5480c5e86a69e675db0d2a19.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d992eab7de49033f5480c5e86a69e675db0d2a19.hip new file mode 100644 index 000000000000..84c378952ce3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d992eab7de49033f5480c5e86a69e675db0d2a19.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d9c23b7f8fcc4e4f4c81f5f00cfd345b98df2e0f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d9c23b7f8fcc4e4f4c81f5f00cfd345b98df2e0f.hip new file mode 100644 index 000000000000..1c2886477b71 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d9c23b7f8fcc4e4f4c81f5f00cfd345b98df2e0f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d9c3e27b522320dcca5ee84fa534b03aae2bfea9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d9c3e27b522320dcca5ee84fa534b03aae2bfea9.hip new file mode 100644 index 000000000000..d93a3dab0059 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_d9c3e27b522320dcca5ee84fa534b03aae2bfea9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_da07d8b5666423da30a95e3b2cabd3839d200981.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_da07d8b5666423da30a95e3b2cabd3839d200981.hip new file mode 100644 index 000000000000..ec6495c8bb0d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_da07d8b5666423da30a95e3b2cabd3839d200981.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_da29a515d14dac02066bcd4701285b9916b43cf5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_da29a515d14dac02066bcd4701285b9916b43cf5.hip new file mode 100644 index 000000000000..badef10c6564 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_da29a515d14dac02066bcd4701285b9916b43cf5.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_da6afccdee4107507a64323e17bf12c46da2b92a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_da6afccdee4107507a64323e17bf12c46da2b92a.hip new file mode 100644 index 000000000000..504de596f41d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_da6afccdee4107507a64323e17bf12c46da2b92a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_da74887afedbd67928fe4d596709f9ff92530611.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_da74887afedbd67928fe4d596709f9ff92530611.hip new file mode 100644 index 000000000000..c1447654d3b0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_da74887afedbd67928fe4d596709f9ff92530611.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_da822ea727fb3543e445e4000f7e6ebb946d6a3b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_da822ea727fb3543e445e4000f7e6ebb946d6a3b.hip new file mode 100644 index 000000000000..4b6115fcd04b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_da822ea727fb3543e445e4000f7e6ebb946d6a3b.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
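// --------------------------------------------------------------------------
// Illustrative sketch (assumption, not part of the generated source): the
// shorter 84-line files are the forward-pass counterparts. Each defines one
// fmha_fwd_traits_ bundle plus one FmhaFwdKernel type and provides a single
// specialization of the forward dispatch hook, approximately:
//
//   template <>
//   float fmha_fwd_<trait_0>(const ck_tile::stream_config& s, fmha_fwd_args a);
//
// It logs the kernel name when s.log_level_ > 0, builds kargs/grids with
// fmha_fwd_create_kargs_and_grids(a), and launches through
// ck_tile::launch_kernel / ck_tile::make_kernel only when compiling for
// gfx90a or gfx942, otherwise returning 0.0. The <trait_0> template argument
// is reconstructed; the elided #include is assumed to declare fmha_fwd_ and
// fmha_fwd_args.
// --------------------------------------------------------------------------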
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_da9f6e1d59132fe96709490af25bd794f267851c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_da9f6e1d59132fe96709490af25bd794f267851c.hip new file mode 100644 index 000000000000..fdb8da65bbe2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_da9f6e1d59132fe96709490af25bd794f267851c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_db0d0cf55d90b3f3c9eecada1db93c420f34b1ae.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_db0d0cf55d90b3f3c9eecada1db93c420f34b1ae.hip new file mode 100644 index 000000000000..de7d4a18f0ec --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_db0d0cf55d90b3f3c9eecada1db93c420f34b1ae.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( 
+ s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_db5016bff9e5dc37184d2b9417eb351c7ea1c322.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_db5016bff9e5dc37184d2b9417eb351c7ea1c322.hip new file mode 100644 index 000000000000..d2fa64da9e9e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_db5016bff9e5dc37184d2b9417eb351c7ea1c322.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_db85839ee8d464c5a81b8dad9839f5e0f4b467a8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_db85839ee8d464c5a81b8dad9839f5e0f4b467a8.hip new file mode 100644 index 000000000000..2d86919fe422 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_db85839ee8d464c5a81b8dad9839f5e0f4b467a8.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_db8f0bd93b352d28c5b6d78f4332026993f0bea4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_db8f0bd93b352d28c5b6d78f4332026993f0bea4.hip new file mode 100644 index 000000000000..3118085b9505 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_db8f0bd93b352d28c5b6d78f4332026993f0bea4.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dbae1670fac6812b2d2cbad973e4b475509ea504.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dbae1670fac6812b2d2cbad973e4b475509ea504.hip new file mode 100644 index 000000000000..97c43eac2511 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dbae1670fac6812b2d2cbad973e4b475509ea504.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dbb06b43d5d65429e23cc717448cf1fffb0cfd74.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dbb06b43d5d65429e23cc717448cf1fffb0cfd74.hip new file mode 100644 index 000000000000..e76974e6def3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dbb06b43d5d65429e23cc717448cf1fffb0cfd74.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = 
ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dbc4135fce01e8731fec7a78d0cc0fdeeae28b90.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dbc4135fce01e8731fec7a78d0cc0fdeeae28b90.hip new file mode 100644 index 000000000000..63d99059cecd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dbc4135fce01e8731fec7a78d0cc0fdeeae28b90.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
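// --------------------------------------------------------------------------
// Illustrative sketch (assumption, not part of the generated source): the
// 79-line files instantiate the dQ post-processing step of the backward pass.
// Their pipeline problem converts from the FmhaBwdTypeConfig AccDataType to
// QGradDataType, which suggests (an inference, not stated in this file) that
// the dq_dk_dv kernels accumulate dQ in a higher-precision buffer that this
// kernel converts to the output dtype. The specializations mirror the main
// backward entry points, roughly:
//
//   template <>
//   float fmha_bwd_convert_dq_<convert_dq_trait_0>(
//       const ck_tile::stream_config& s, fmha_bwd_args a);
//
//   template <>
//   void fmha_bwd_convert_dq_oneshot_<convert_dq_trait_0>(
//       const ck_tile::stream_config& s, fmha_bwd_args a);
//
// with the same gfx90a/gfx942 guard; <convert_dq_trait_0> is a reconstructed
// template argument, taken from the trait alias defined in this file.
// --------------------------------------------------------------------------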
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 64, + 256, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<256, + ck_tile::fp16_t, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dbcea8f7b5930abf76eecefce92d0db785d2df5d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dbcea8f7b5930abf76eecefce92d0db785d2df5d.hip new file mode 100644 index 000000000000..7f0922b1a0f8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dbcea8f7b5930abf76eecefce92d0db785d2df5d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dbde2ef18e2174ebe13a6e7c8c2a6b05a6612047.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dbde2ef18e2174ebe13a6e7c8c2a6b05a6612047.hip new file mode 100644 index 000000000000..606feadb0e71 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dbde2ef18e2174ebe13a6e7c8c2a6b05a6612047.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc039d422a57c159ea4dbcc867d766ff1b356a07.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc039d422a57c159ea4dbcc867d766ff1b356a07.hip new file mode 100644 index 000000000000..64d84451ed80 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc039d422a57c159ea4dbcc867d766ff1b356a07.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc08afbff5def8bcb4e823657ce01f57c9dc77c9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc08afbff5def8bcb4e823657ce01f57c9dc77c9.hip new file mode 100644 index 000000000000..35bc73101c25 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc08afbff5def8bcb4e823657ce01f57c9dc77c9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc184767d723f4995791848cdc68bd948408204f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc184767d723f4995791848cdc68bd948408204f.hip new file mode 100644 index 000000000000..dad0aefdc8e7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc184767d723f4995791848cdc68bd948408204f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, 
ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc1a7f9b1afeba6690fdc0d0d1755ea89c805573.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc1a7f9b1afeba6690fdc0d0d1755ea89c805573.hip new file mode 100644 index 000000000000..7ef4bb17825b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc1a7f9b1afeba6690fdc0d0d1755ea89c805573.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = 
fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc34b6ef496d4e0d8fbbe10731d4a7b1c136c036.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc34b6ef496d4e0d8fbbe10731d4a7b1c136c036.hip new file mode 100644 index 000000000000..067462626301 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc34b6ef496d4e0d8fbbe10731d4a7b1c136c036.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc3d625c5ad3e871f5a727ac946df642d988b9ab.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc3d625c5ad3e871f5a727ac946df642d988b9ab.hip new file mode 100644 index 000000000000..4330c6ede23d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc3d625c5ad3e871f5a727ac946df642d988b9ab.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) 
+ return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc4d27535b9570b8f4b790470a83c1d0a9a2b6ce.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc4d27535b9570b8f4b790470a83c1d0a9a2b6ce.hip new file mode 100644 index 000000000000..a8387a6b8594 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc4d27535b9570b8f4b790470a83c1d0a9a2b6ce.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc5ba6d73f331c76e696953606c5b347b6a46f3f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc5ba6d73f331c76e696953606c5b347b6a46f3f.hip new file mode 100644 index 000000000000..f023946e8d6d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc5ba6d73f331c76e696953606c5b347b6a46f3f.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc62a8db637d32e7dfdb2521cbdae6e1fbbd5fd1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc62a8db637d32e7dfdb2521cbdae6e1fbbd5fd1.hip new file mode 100644 index 000000000000..0f2fd246d49c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc62a8db637d32e7dfdb2521cbdae6e1fbbd5fd1.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc818f3ce244743cb1dbff9aca399df90742a6d0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc818f3ce244743cb1dbff9aca399df90742a6d0.hip new file mode 100644 index 000000000000..12eb566a4a90 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc818f3ce244743cb1dbff9aca399df90742a6d0.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 64, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<64, + ck_tile::bf16_t, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc91797c1474a368e9cb056b50b4629d7736c3cb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc91797c1474a368e9cb056b50b4629d7736c3cb.hip new file mode 100644 index 000000000000..8581e0ed61b7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc91797c1474a368e9cb056b50b4629d7736c3cb.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc9e54273c0ea2358fb573a7d918aa7b09fe07f9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc9e54273c0ea2358fb573a7d918aa7b09fe07f9.hip new file mode 100644 index 000000000000..b5fb86f5f564 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dc9e54273c0ea2358fb573a7d918aa7b09fe07f9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dcf815ef540060cc7ed43e1c57a28e1d080c5621.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dcf815ef540060cc7ed43e1c57a28e1d080c5621.hip new file mode 100644 index 000000000000..042fc02f2077 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dcf815ef540060cc7ed43e1c57a28e1d080c5621.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dd10bbf37503bbc92af82bc3487989b41b20ca85.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dd10bbf37503bbc92af82bc3487989b41b20ca85.hip new file mode 100644 index 000000000000..e0285cef988f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dd10bbf37503bbc92af82bc3487989b41b20ca85.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dd11806cd2d3ef1127f676b2d98bf8fff2a1e5ab.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dd11806cd2d3ef1127f676b2d98bf8fff2a1e5ab.hip new file mode 100644 index 000000000000..7c40cc29684f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dd11806cd2d3ef1127f676b2d98bf8fff2a1e5ab.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dd35634440edb25cb095800b882c70aaceca1dbb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dd35634440edb25cb095800b882c70aaceca1dbb.hip new file mode 100644 index 000000000000..84a280472358 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dd35634440edb25cb095800b882c70aaceca1dbb.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dd67d442001d2b167e70e8730abde4d4461b8569.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dd67d442001d2b167e70e8730abde4d4461b8569.hip new file mode 100644 index 000000000000..1db14e6f717a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dd67d442001d2b167e70e8730abde4d4461b8569.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dd9494d9ac35eba6794a4f9120d2db9932596ef8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dd9494d9ac35eba6794a4f9120d2db9932596ef8.hip new file mode 100644 index 000000000000..08a8e86c5123 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dd9494d9ac35eba6794a4f9120d2db9932596ef8.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, true,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dda8d021381083bc48b7fb1840729254dd8e5137.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dda8d021381083bc48b7fb1840729254dd8e5137.hip new file mode 100644 index 000000000000..a15cc8532797 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dda8d021381083bc48b7fb1840729254dd8e5137.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ddcb1cfea1b0dbe50a02252cba99428fd977527e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ddcb1cfea1b0dbe50a02252cba99428fd977527e.hip new file mode 100644 index 000000000000..206ee07ae745 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ddcb1cfea1b0dbe50a02252cba99428fd977527e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dde93ffe7fca311e136e42fbcd12b05c9fc7174c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dde93ffe7fca311e136e42fbcd12b05c9fc7174c.hip new file mode 100644 index 000000000000..5103e37e5cdf --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dde93ffe7fca311e136e42fbcd12b05c9fc7174c.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 64, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<64, + ck_tile::fp16_t, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ddf5339054f47d9ed6cc7f9e66ab21ce3bccf3db.hip 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ddf5339054f47d9ed6cc7f9e66ab21ce3bccf3db.hip new file mode 100644 index 000000000000..f580b6bc6ab8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ddf5339054f47d9ed6cc7f9e66ab21ce3bccf3db.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + 
+template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_de1ff66d2aeb47d2fdccaa4bb6b9d066b380c99e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_de1ff66d2aeb47d2fdccaa4bb6b9d066b380c99e.hip new file mode 100644 index 000000000000..0849a2c6380f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_de1ff66d2aeb47d2fdccaa4bb6b9d066b380c99e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_de26a187c4db06115072a5132e1166b5b03368b0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_de26a187c4db06115072a5132e1166b5b03368b0.hip new file mode 100644 index 000000000000..2cbeb6876a0d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_de26a187c4db06115072a5132e1166b5b03368b0.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
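Each autogenerated .hip file above instantiates one backward-kernel configuration and then provides three explicit specializations keyed on a per-file trait type (dq_dk_dv_trait_0): a timed launcher fmha_bwd_dq_dk_dv_, a fire-and-forget fmha_bwd_dq_dk_dv_oneshot_, and a name query fmha_bwd_dq_dk_dv_get_name_. The sketch below only illustrates that specialization-per-trait dispatch idea with stand-in types and invented trait names; it is not the generated fmha_bwd.hpp dispatcher.

    // Illustrative C++ sketch with stand-in types; the real code keys on
    // fmha_bwd_dq_dk_dv_traits_<...> and launches HIP kernels instead.
    #include <iostream>

    struct stream_config {};   // stands in for ck_tile::stream_config
    struct fmha_bwd_args {};   // stands in for the real argument struct

    // Primary template: one explicit specialization per kernel configuration.
    template <typename Trait>
    float fmha_bwd_dq_dk_dv_(const stream_config&, fmha_bwd_args);

    // Invented trait tags standing in for two generated configurations.
    struct trait_hdim64_fp16 {};
    struct trait_hdim128_bf16 {};

    template <>
    float fmha_bwd_dq_dk_dv_<trait_hdim64_fp16>(const stream_config&, fmha_bwd_args)
    {
        return 1.0f; // would launch the hdim=64 fp16 kernel and return elapsed ms
    }

    template <>
    float fmha_bwd_dq_dk_dv_<trait_hdim128_bf16>(const stream_config&, fmha_bwd_args)
    {
        return 2.0f; // would launch the hdim=128 bf16 kernel and return elapsed ms
    }

    // A caller maps runtime parameters to the matching specialization.
    float run_bwd(int hdim, bool is_bf16, const stream_config& s, fmha_bwd_args a)
    {
        if (hdim == 64  && !is_bf16) return fmha_bwd_dq_dk_dv_<trait_hdim64_fp16>(s, a);
        if (hdim == 128 &&  is_bf16) return fmha_bwd_dq_dk_dv_<trait_hdim128_bf16>(s, a);
        return 0.0f; // unsupported configuration
    }

    int main()
    {
        std::cout << run_bwd(64, false, stream_config{}, fmha_bwd_args{}) << "\n";
    }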
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_de36bc309877917a18fd21acb30563c7e2f233c1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_de36bc309877917a18fd21acb30563c7e2f233c1.hip new file mode 100644 index 000000000000..c58cb2d1f3ad --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_de36bc309877917a18fd21acb30563c7e2f233c1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_de5359f0fba3da9dfed06ddbea8fe2a33a9cf40c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_de5359f0fba3da9dfed06ddbea8fe2a33a9cf40c.hip new file mode 100644 index 000000000000..06b3d532e344 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_de5359f0fba3da9dfed06ddbea8fe2a33a9cf40c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_de6683d175affaa5ff261ab8503f64172d8eba8b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_de6683d175affaa5ff261ab8503f64172d8eba8b.hip new file mode 100644 index 000000000000..f17c1d628825 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_de6683d175affaa5ff261ab8503f64172d8eba8b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
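Two launch paths appear in every file above: fmha_bwd_dq_dk_dv_ wraps the kernel in ck_tile::launch_kernel so the stream_config can time it and returns a float (0.0 when the translation unit is compiled for a target other than gfx90a or gfx942), while the oneshot variant builds the kernel callable and invokes it on the stream directly without returning a timing. A rough sketch of that split, using placeholder types rather than the real ck_tile API:

    // Placeholder-type sketch of timed vs. one-shot launching; assumed names,
    // not the ck_tile API. Timing here uses std::chrono and pretends the
    // enqueue is synchronous, which a real HIP launch is not.
    #include <chrono>

    struct stream_config { int stream_id_ = 0; };

    template <typename Kernel, typename Args>
    float launch_timed(const stream_config& s, Kernel k, Args a)
    {
        auto t0 = std::chrono::steady_clock::now();
        k(s, a);  // enqueue and (assumed) wait for completion
        auto t1 = std::chrono::steady_clock::now();
        return std::chrono::duration<float, std::milli>(t1 - t0).count();
    }

    template <typename Kernel, typename Args>
    void launch_oneshot(const stream_config& s, Kernel k, Args a)
    {
        k(s, a);  // enqueue only; the caller owns synchronization
    }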
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_de7eb562a7eff31d589e12945d80233aac202ae2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_de7eb562a7eff31d589e12945d80233aac202ae2.hip new file mode 100644 index 000000000000..cba1f9164df1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_de7eb562a7eff31d589e12945d80233aac202ae2.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_de85901d66dc04b1143bb6404445baf65693b781.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_de85901d66dc04b1143bb6404445baf65693b781.hip new file mode 100644 index 000000000000..62d93f596d46 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_de85901d66dc04b1143bb6404445baf65693b781.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_deb9ec2cccab94920e40f62a1f0f094acd919d07.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_deb9ec2cccab94920e40f62a1f0f094acd919d07.hip new file mode 100644 index 000000000000..e53a12d64ba1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_deb9ec2cccab94920e40f62a1f0f094acd919d07.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_df0b2bcba57e77d975ec5304fc50cbd09cddf4bb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_df0b2bcba57e77d975ec5304fc50cbd09cddf4bb.hip new file mode 100644 index 000000000000..8658271826c3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_df0b2bcba57e77d975ec5304fc50cbd09cddf4bb.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_df4bb75ca79f805a81fbad750ad22f6d22b0d8ff.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_df4bb75ca79f805a81fbad750ad22f6d22b0d8ff.hip new file mode 100644 index 000000000000..389c1d5340f5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_df4bb75ca79f805a81fbad750ad22f6d22b0d8ff.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_df4c9eb48da49a61957537270d94e56cb4e426be.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_df4c9eb48da49a61957537270d94e56cb4e426be.hip new file mode 100644 index 000000000000..b9354ed7dd1c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_df4c9eb48da49a61957537270d94e56cb4e426be.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_df5b1c6758d4b8540158299dd0362297083084c2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_df5b1c6758d4b8540158299dd0362297083084c2.hip new file mode 100644 index 000000000000..20337081732b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_df5b1c6758d4b8540158299dd0362297083084c2.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_df645b3888dc8d1df50c47c0d75822eebd3eb019.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_df645b3888dc8d1df50c47c0d75822eebd3eb019.hip new file mode 100644 index 000000000000..b44989396c5f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_df645b3888dc8d1df50c47c0d75822eebd3eb019.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + 
return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_df66feebc9a0dcc508ce002c255154622875e524.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_df66feebc9a0dcc508ce002c255154622875e524.hip new file mode 100644 index 000000000000..349d2b5d5c64 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_df66feebc9a0dcc508ce002c255154622875e524.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dfcd68acfca68d1acac94f493e25be0ef20f209f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dfcd68acfca68d1acac94f493e25be0ef20f209f.hip new file mode 100644 index 000000000000..fc4eaf24627d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_dfcd68acfca68d1acac94f493e25be0ef20f209f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e02a198f23c409b715761b702d7b0e6e5992701f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e02a198f23c409b715761b702d7b0e6e5992701f.hip new file mode 100644 index 000000000000..a1a5e869ef39 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e02a198f23c409b715761b702d7b0e6e5992701f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e035773419a9b3631698a3d375d829af55f7731e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e035773419a9b3631698a3d375d829af55f7731e.hip new file mode 100644 index 000000000000..d8a8fb800308 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e035773419a9b3631698a3d375d829af55f7731e.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = 
k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e088f0f7363804cf5403adef70828ab32d09a02a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e088f0f7363804cf5403adef70828ab32d09a02a.hip new file mode 100644 index 000000000000..34543da0b0f1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e088f0f7363804cf5403adef70828ab32d09a02a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << 
std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e0966fa1ff013e477b1706928de6cb7f8587c154.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e0966fa1ff013e477b1706928de6cb7f8587c154.hip new file mode 100644 index 000000000000..f4d7ae841c51 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e0966fa1ff013e477b1706928de6cb7f8587c154.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + 
+using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e09d9baa269dfbb30b714389d1733be51cc419b7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e09d9baa269dfbb30b714389d1733be51cc419b7.hip new file mode 100644 index 000000000000..822657151301 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e09d9baa269dfbb30b714389d1733be51cc419b7.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 64, + 256, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<256, + ck_tile::fp16_t, + false, + false, + true, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e0e48d7edfe9513f24ad9fae68cac3aa940b17dd.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e0e48d7edfe9513f24ad9fae68cac3aa940b17dd.hip new file mode 100644 index 000000000000..94176d35b88a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e0e48d7edfe9513f24ad9fae68cac3aa940b17dd.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e10f47a44400de385ddbeb99475b717c5646fb41.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e10f47a44400de385ddbeb99475b717c5646fb41.hip new file mode 100644 index 000000000000..a3f727c71b1e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e10f47a44400de385ddbeb99475b717c5646fb41.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e11a3b7d4fdfed64e64f7a95dbc64eff541092d6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e11a3b7d4fdfed64e64f7a95dbc64eff541092d6.hip new file mode 100644 index 000000000000..d16a0f228f2b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e11a3b7d4fdfed64e64f7a95dbc64eff541092d6.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 128, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128, + ck_tile::bf16_t, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e13b86fe4e153e0bfa8d1e75f3641fe32b0c5149.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e13b86fe4e153e0bfa8d1e75f3641fe32b0c5149.hip new file mode 100644 index 000000000000..c760dc0c539f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e13b86fe4e153e0bfa8d1e75f3641fe32b0c5149.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e16075c3a5fcfe63ba12e854bb1fed6873f014ab.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e16075c3a5fcfe63ba12e854bb1fed6873f014ab.hip new file mode 100644 index 000000000000..d7034adb3298 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e16075c3a5fcfe63ba12e854bb1fed6873f014ab.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 32, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, false, true>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e16edb824cecf459a8ec51b8dc74b1e06369aceb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e16edb824cecf459a8ec51b8dc74b1e06369aceb.hip new file mode 100644 index 000000000000..44bf03929834 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e16edb824cecf459a8ec51b8dc74b1e06369aceb.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e1c1a31a1d8556cbe0b6ea76faacc78855108539.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e1c1a31a1d8556cbe0b6ea76faacc78855108539.hip new file mode 100644 index 000000000000..b477771fb5fc --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e1c1a31a1d8556cbe0b6ea76faacc78855108539.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, 
ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e1cc934ba7baab1a2eb062df1e4ee5066e9ffbc3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e1cc934ba7baab1a2eb062df1e4ee5066e9ffbc3.hip new file mode 100644 index 000000000000..f0a895058383 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e1cc934ba7baab1a2eb062df1e4ee5066e9ffbc3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = 
fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e1d85ad2c9d197f501267fe0804e6985802fbd18.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e1d85ad2c9d197f501267fe0804e6985802fbd18.hip new file mode 100644 index 000000000000..8ff56f381029 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e1d85ad2c9d197f501267fe0804e6985802fbd18.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e2762543d3380185e304f84749a70db1b8d3dd8c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e2762543d3380185e304f84749a70db1b8d3dd8c.hip new file mode 100644 index 000000000000..e30f1a1927e7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e2762543d3380185e304f84749a70db1b8d3dd8c.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e28fd64c2f2b27577109a984e6ab82f5f0fcb296.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e28fd64c2f2b27577109a984e6ab82f5f0fcb296.hip new file mode 100644 index 000000000000..c0f84899ac51 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e28fd64c2f2b27577109a984e6ab82f5f0fcb296.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e2b629c37cf94134693ce455b8c88b72a39df7fe.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e2b629c37cf94134693ce455b8c88b72a39df7fe.hip new file mode 100644 index 000000000000..ff3d0628f2fb --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e2b629c37cf94134693ce455b8c88b72a39df7fe.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e2bf6805a489739abb77c13173d57723e9304afa.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e2bf6805a489739abb77c13173d57723e9304afa.hip new file mode 100644 index 000000000000..d8baeba584cc --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e2bf6805a489739abb77c13173d57723e9304afa.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e2c9f955f227430c6224ebc347649386be7f01eb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e2c9f955f227430c6224ebc347649386be7f01eb.hip new file mode 100644 index 000000000000..2900490d7833 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e2c9f955f227430c6224ebc347649386be7f01eb.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e2deafd2f36cee29109fb824e0135407453adcfe.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e2deafd2f36cee29109fb824e0135407453adcfe.hip new file mode 100644 index 000000000000..bc7255519535 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e2deafd2f36cee29109fb824e0135407453adcfe.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e3015c5d50481547aa5754d042d9d7040cf1c7ff.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e3015c5d50481547aa5754d042d9d7040cf1c7ff.hip new file mode 100644 index 000000000000..39c83dc91af7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e3015c5d50481547aa5754d042d9d7040cf1c7ff.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e307a1b0d5a8f94e0a0f4032f401d20b4b643523.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e307a1b0d5a8f94e0a0f4032f401d20b4b643523.hip new file mode 100644 index 000000000000..b13364aa13c3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e307a1b0d5a8f94e0a0f4032f401d20b4b643523.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e334e691714f0b99773c2ac515ed82de0f387065.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e334e691714f0b99773c2ac515ed82de0f387065.hip new file mode 100644 index 000000000000..3fa6911d7dec --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e334e691714f0b99773c2ac515ed82de0f387065.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( 
+ s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e34b7e452a4db74189334697e3a240ad68085f0e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e34b7e452a4db74189334697e3a240ad68085f0e.hip new file mode 100644 index 000000000000..fbeed9e8bab1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e34b7e452a4db74189334697e3a240ad68085f0e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 
= fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e389d0e4442cd8304081892ddc75043e68a6398c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e389d0e4442cd8304081892ddc75043e68a6398c.hip new file mode 100644 index 000000000000..4520e939e63c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e389d0e4442cd8304081892ddc75043e68a6398c.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e465193d97d43237c22c04478ca5833011d8dc8b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e465193d97d43237c22c04478ca5833011d8dc8b.hip new file mode 100644 index 000000000000..870b19560a3d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e465193d97d43237c22c04478ca5833011d8dc8b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e477abef05ff37ec27705eda51896e2aa3a04966.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e477abef05ff37ec27705eda51896e2aa3a04966.hip new file mode 100644 index 000000000000..e3a7e8eb65ee --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e477abef05ff37ec27705eda51896e2aa3a04966.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e4d9a2396ceccdadab24602f30e9070901a76dc7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e4d9a2396ceccdadab24602f30e9070901a76dc7.hip new file mode 100644 index 000000000000..1babb48d2ea2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e4d9a2396ceccdadab24602f30e9070901a76dc7.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e502730dea6987e2c038446c448aa08bdcc23113.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e502730dea6987e2c038446c448aa08bdcc23113.hip new file mode 100644 index 000000000000..1a309d876572 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e502730dea6987e2c038446c448aa08bdcc23113.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+
+// auto generated by generate.py
+#include
+
+using fmha_dtype_0 = ck_tile::bf16_t;
+
+using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>;
+using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>;
+
+using fmha_shape_0 = ck_tile::TileFmhaShape,
+    fmha_warp_tile_0,
+    ck_tile::sequence<4, 1, 1>,
+    fmha_warp_tile_0,
+    true>;
+
+using fmha_trait_0 = ck_tile::TileFmhaTraits;
+using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask;
+
+using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem<
+    typename FmhaFwdTypeConfig::QDataType,
+    typename FmhaFwdTypeConfig::KDataType,
+    typename FmhaFwdTypeConfig::VDataType,
+    typename FmhaFwdTypeConfig::SaccDataType,
+    typename FmhaFwdTypeConfig::SMPLComputeDataType,
+    typename FmhaFwdTypeConfig::BiasDataType,
+    typename FmhaFwdTypeConfig::RandValOutputDataType,
+    typename FmhaFwdTypeConfig::LSEDataType,
+    typename FmhaFwdTypeConfig::PDataType,
+    typename FmhaFwdTypeConfig::OaccDataType,
+    typename FmhaFwdTypeConfig::ODataType,
+    fmha_shape_0,
+    false,
+    fmha_mask_0,
+    fmha_trait_0>;
+
+using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS<
+    fmha_pipeline_problem_0>;
+
+using fmha_epilogue_0 =
+    ck_tile::Default2DEpilogue::OaccDataType,
+    typename FmhaFwdTypeConfig::ODataType,
+    true, true>>;
+
+using fmha_kernel_0 =
+    ck_tile::FmhaFwdKernel,
+    fmha_pipeline_0,
+    fmha_epilogue_0>;
+
+using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true,
+    ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>;
+
+#include
+
+template<>
+float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a)
+{
+    using k_ = fmha_kernel_0;
+    if(s.log_level_ > 0)
+        std::cout << ", " << k_::GetName() << std::flush;
+    auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a);
+    constexpr dim3 blocks = k_::BlockSize();
+    constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
+#if (defined(__gfx90a__) || defined(__gfx942__))
+    return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs));
+#else
+    return 0.0;
+#endif
+}
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e514c6b4bc75d95a150104a17972abae77cb47ed.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e514c6b4bc75d95a150104a17972abae77cb47ed.hip
new file mode 100644
index 000000000000..d3da295a237c
--- /dev/null
+++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e514c6b4bc75d95a150104a17972abae77cb47ed.hip
@@ -0,0 +1,84 @@
+// ==========================================
+// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
+// @generated
+// ==========================================
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e52e3053f30f780f346fa6b7a836ad2554cb85df.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e52e3053f30f780f346fa6b7a836ad2554cb85df.hip new file mode 100644 index 000000000000..112edb20d127 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e52e3053f30f780f346fa6b7a836ad2554cb85df.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e56757fb17f5e94a6ba1fb14540a68c36d571159.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e56757fb17f5e94a6ba1fb14540a68c36d571159.hip new file mode 100644 index 000000000000..15d9834ce7a4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e56757fb17f5e94a6ba1fb14540a68c36d571159.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e578ec9e09d3b78dca6b5bf0be1538657f02f319.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e578ec9e09d3b78dca6b5bf0be1538657f02f319.hip new file mode 100644 index 000000000000..1f2162412177 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e578ec9e09d3b78dca6b5bf0be1538657f02f319.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e5935fbda313d3518f142f43d46f56c600f69286.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e5935fbda313d3518f142f43d46f56c600f69286.hip new file mode 100644 index 000000000000..3f35cd96001e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e5935fbda313d3518f142f43d46f56c600f69286.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e5b2bb9f8466de1ad5210e4c39ee7b8ecacdffa9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e5b2bb9f8466de1ad5210e4c39ee7b8ecacdffa9.hip new file mode 100644 index 000000000000..fbd7b54057aa --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e5b2bb9f8466de1ad5210e4c39ee7b8ecacdffa9.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e5b65fc519ea7cfcd19f7eddbc3acad6842ff558.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e5b65fc519ea7cfcd19f7eddbc3acad6842ff558.hip new file mode 100644 index 000000000000..155a6049668a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e5b65fc519ea7cfcd19f7eddbc3acad6842ff558.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e5c5079636a4a31a849ce8a5af89d50330a74628.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e5c5079636a4a31a849ce8a5af89d50330a74628.hip new file mode 100644 index 000000000000..4f09e38bc855 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e5c5079636a4a31a849ce8a5af89d50330a74628.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 32, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<32, + ck_tile::bf16_t, + false, + false, + true, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e5ccd5f7ddc894b2717112cbfc766804e02b7bd1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e5ccd5f7ddc894b2717112cbfc766804e02b7bd1.hip new file mode 100644 index 000000000000..46d477547a41 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e5ccd5f7ddc894b2717112cbfc766804e02b7bd1.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+
+// auto generated by generate.py
+#include
+
+using fmha_dtype_0 = ck_tile::bf16_t;
+
+using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>;
+using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>;
+
+using fmha_shape_0 = ck_tile::TileFmhaShape,
+    fmha_warp_tile_0,
+    ck_tile::sequence<4, 1, 1>,
+    fmha_warp_tile_0,
+    true>;
+
+using fmha_trait_0 = ck_tile::TileFmhaTraits;
+using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask;
+
+using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem<
+    typename FmhaFwdTypeConfig::QDataType,
+    typename FmhaFwdTypeConfig::KDataType,
+    typename FmhaFwdTypeConfig::VDataType,
+    typename FmhaFwdTypeConfig::SaccDataType,
+    typename FmhaFwdTypeConfig::SMPLComputeDataType,
+    typename FmhaFwdTypeConfig::BiasDataType,
+    typename FmhaFwdTypeConfig::RandValOutputDataType,
+    typename FmhaFwdTypeConfig::LSEDataType,
+    typename FmhaFwdTypeConfig::PDataType,
+    typename FmhaFwdTypeConfig::OaccDataType,
+    typename FmhaFwdTypeConfig::ODataType,
+    fmha_shape_0,
+    true,
+    fmha_mask_0,
+    fmha_trait_0>;
+
+using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync<
+    fmha_pipeline_problem_0>;
+
+using fmha_epilogue_0 =
+    ck_tile::Default2DEpilogue::OaccDataType,
+    typename FmhaFwdTypeConfig::ODataType,
+    true, true>>;
+
+using fmha_kernel_0 =
+    ck_tile::FmhaFwdKernel,
+    fmha_pipeline_0,
+    fmha_epilogue_0>;
+
+using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, true,128, 64, 32, 64, 32, 64, true,
+    ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>;
+
+#include
+
+template<>
+float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a)
+{
+    using k_ = fmha_kernel_0;
+    if(s.log_level_ > 0)
+        std::cout << ", " << k_::GetName() << std::flush;
+    auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a);
+    constexpr dim3 blocks = k_::BlockSize();
+    constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
+#if (defined(__gfx90a__) || defined(__gfx942__))
+    return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs));
+#else
+    return 0.0;
+#endif
+}
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e618fb4e529104fc90069c8779ce5463460bd516.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e618fb4e529104fc90069c8779ce5463460bd516.hip
new file mode 100644
index 000000000000..32d64eb6f567
--- /dev/null
+++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e618fb4e529104fc90069c8779ce5463460bd516.hip
@@ -0,0 +1,144 @@
+// ==========================================
+// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
+// @generated
+// ==========================================
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e638053e01268a4c5883620fc6a9901951e2e01a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e638053e01268a4c5883620fc6a9901951e2e01a.hip new file mode 100644 index 000000000000..fa466e7cc755 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e638053e01268a4c5883620fc6a9901951e2e01a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e639a1e84faa98477b05df71d363b9ff0f9b2760.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e639a1e84faa98477b05df71d363b9ff0f9b2760.hip new file mode 100644 index 000000000000..3c45952d2d85 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e639a1e84faa98477b05df71d363b9ff0f9b2760.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e68a9e05debd456a9975953f7b0d510e7a0f6978.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e68a9e05debd456a9975953f7b0d510e7a0f6978.hip new file mode 100644 index 000000000000..5fc8edd83a30 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e68a9e05debd456a9975953f7b0d510e7a0f6978.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e6973d75297bd2c3432a7c88e8a9ee1c9ae693bf.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e6973d75297bd2c3432a7c88e8a9ee1c9ae693bf.hip new file mode 100644 index 000000000000..bcd0166a22d3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e6973d75297bd2c3432a7c88e8a9ee1c9ae693bf.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e6b53fb8d81148ff384d31a703bb4c2e7a5a33af.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e6b53fb8d81148ff384d31a703bb4c2e7a5a33af.hip new file mode 100644 index 000000000000..3f007f131c1e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e6b53fb8d81148ff384d31a703bb4c2e7a5a33af.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 32, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<32, + ck_tile::fp16_t, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e6e0ec1db1ea308e226f675e68e29b839e41b252.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e6e0ec1db1ea308e226f675e68e29b839e41b252.hip new file mode 100644 index 000000000000..d1df360bbccf --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e6e0ec1db1ea308e226f675e68e29b839e41b252.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e6e6b10e73733716e71ebf5a53703fb935fc5e02.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e6e6b10e73733716e71ebf5a53703fb935fc5e02.hip new file mode 100644 index 000000000000..d79b86cd3364 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e6e6b10e73733716e71ebf5a53703fb935fc5e02.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 64, + true, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<64, + ck_tile::fp16_t, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e7153f9a9b0b7c54ddf2debbe297efcffbb4fcfa.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e7153f9a9b0b7c54ddf2debbe297efcffbb4fcfa.hip new file mode 100644 index 000000000000..c4e6e1b8e8ee --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e7153f9a9b0b7c54ddf2debbe297efcffbb4fcfa.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 128, + true, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, true, true, true>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e73a776ae4ba68c23acab1a5a6381684051738ab.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e73a776ae4ba68c23acab1a5a6381684051738ab.hip new file mode 100644 index 000000000000..498bc125abfc --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e73a776ae4ba68c23acab1a5a6381684051738ab.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e75c757c67aa23cb88e1aced6fcf36b7b28391db.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e75c757c67aa23cb88e1aced6fcf36b7b28391db.hip new file mode 100644 index 000000000000..19ff2f35e99c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e75c757c67aa23cb88e1aced6fcf36b7b28391db.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 64, + true, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<64, + ck_tile::fp16_t, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e75d492ac3a6ab75648056bcf26250a4aa929cfd.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e75d492ac3a6ab75648056bcf26250a4aa929cfd.hip new file mode 100644 index 000000000000..bd6e14286bde --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e75d492ac3a6ab75648056bcf26250a4aa929cfd.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e76879f8ff4796f48ad87ff8003f4f6e6adca9a0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e76879f8ff4796f48ad87ff8003f4f6e6adca9a0.hip new file mode 100644 index 000000000000..26e0c0a3b05a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e76879f8ff4796f48ad87ff8003f4f6e6adca9a0.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e7ae1294b6dea5c8b93c2b814fa7460c4047105b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e7ae1294b6dea5c8b93c2b814fa7460c4047105b.hip new file mode 100644 index 000000000000..e5c376daa8c3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e7ae1294b6dea5c8b93c2b814fa7460c4047105b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e7b2eb64b66d46359fab44333c2c484f4c9dd5de.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e7b2eb64b66d46359fab44333c2c484f4c9dd5de.hip new file mode 100644 index 000000000000..bf0cd8eecce6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e7b2eb64b66d46359fab44333c2c484f4c9dd5de.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e7c0a99e949baa5f3a7ee2d6e84427982f82f76d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e7c0a99e949baa5f3a7ee2d6e84427982f82f76d.hip new file mode 100644 index 000000000000..26717f208199 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e7c0a99e949baa5f3a7ee2d6e84427982f82f76d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e7d37e7ee96c392fa24c02a9143438a3a7d05741.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e7d37e7ee96c392fa24c02a9143438a3a7d05741.hip new file mode 100644 index 000000000000..b149181876e4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e7d37e7ee96c392fa24c02a9143438a3a7d05741.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, 
ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e7de729aa50c10d8101ef504138c3769e3286753.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e7de729aa50c10d8101ef504138c3769e3286753.hip new file mode 100644 index 000000000000..891b5d350573 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e7de729aa50c10d8101ef504138c3769e3286753.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e83c604d1b8260958becd1c7c209745ff9151715.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e83c604d1b8260958becd1c7c209745ff9151715.hip new file mode 100644 index 000000000000..90b09dd4d9ab --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e83c604d1b8260958becd1c7c209745ff9151715.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e89bcea4393593313d18a4aa6dcb44cd75bc828d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e89bcea4393593313d18a4aa6dcb44cd75bc828d.hip new file mode 100644 index 000000000000..df1b9f1e97c4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e89bcea4393593313d18a4aa6dcb44cd75bc828d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
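Every generated file in this directory follows the same template: it pins one tile/pipeline/epilogue configuration, defines a traits alias naming that configuration, and then provides explicit specializations of the matching launch helpers (fmha_fwd_<trait>, fmha_bwd_dq_dk_dv_<trait> together with its _oneshot_ and _get_name_ companions, fmha_bwd_convert_dq_<trait>, and so on). The body of each launcher is wrapped in #if (defined(__gfx90a__) || defined(__gfx942__)), so the ck_tile::launch_kernel call is only compiled for the supported architectures and the specialization degenerates to a stub returning 0.0 everywhere else. The host-side sketch below is illustrative only and is not part of this patch; it shows one way a caller could mirror that compile-time guard at runtime before dispatching into these kernels. hipGetDeviceProperties and hipDeviceProp_t::gcnArchName are existing HIP APIs, while the helper name itself is invented for the example.

    // Illustrative sketch, not part of the generated sources: check whether the
    // current device is one of the architectures these kernels are built for.
    #include <hip/hip_runtime.h>
    #include <string>

    static bool device_supports_ck_flash_attention(int device_id) {
        hipDeviceProp_t prop{};
        if (hipGetDeviceProperties(&prop, device_id) != hipSuccess)
            return false;
        // gcnArchName looks like "gfx942:sramecc+:xnack-"; match on the prefix,
        // mirroring the __gfx90a__/__gfx942__ guard used in the launchers above.
        const std::string arch(prop.gcnArchName);
        return arch.rfind("gfx90a", 0) == 0 || arch.rfind("gfx942", 0) == 0;
    }

Guarding the launch at compile time keeps each per-configuration object file trivially small on unsupported targets while still letting the dispatcher link against every specialization; the stub return value of 0.0 simply indicates that no kernel time was measured.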
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e8a9427f34bbf5ddb28a39161acc36806e68f2d0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e8a9427f34bbf5ddb28a39161acc36806e68f2d0.hip new file mode 100644 index 000000000000..fcb82a0e604d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e8a9427f34bbf5ddb28a39161acc36806e68f2d0.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 32, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<32, + ck_tile::fp16_t, + false, + false, + true, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e8d8fe5f4f8641998b8b805a20b2ca92d019ee59.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e8d8fe5f4f8641998b8b805a20b2ca92d019ee59.hip new file mode 100644 index 000000000000..b576966ede2f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e8d8fe5f4f8641998b8b805a20b2ca92d019ee59.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e8d9b65558398c0c10127b560807578ef117d7ed.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e8d9b65558398c0c10127b560807578ef117d7ed.hip new file mode 100644 index 000000000000..57f19f7cc94f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e8d9b65558398c0c10127b560807578ef117d7ed.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e907e8d1089557dfcc95a05160be5092e9119a53.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e907e8d1089557dfcc95a05160be5092e9119a53.hip new file mode 100644 index 000000000000..3616371a25be --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e907e8d1089557dfcc95a05160be5092e9119a53.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e95e3908479965856843317c8b0c42a6961dfd23.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e95e3908479965856843317c8b0c42a6961dfd23.hip new file mode 100644 index 000000000000..a61e1532a18a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e95e3908479965856843317c8b0c42a6961dfd23.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e986d5f8d5591f3e0f1cdfad19c38c420fd93023.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e986d5f8d5591f3e0f1cdfad19c38c420fd93023.hip new file mode 100644 index 000000000000..04c347066e51 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e986d5f8d5591f3e0f1cdfad19c38c420fd93023.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e9b04e6d5527ba0b8089ba8bdd264e2d5759338b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e9b04e6d5527ba0b8089ba8bdd264e2d5759338b.hip new file mode 100644 index 000000000000..6d0812a7ab09 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e9b04e6d5527ba0b8089ba8bdd264e2d5759338b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e9b53fa68641f45baabf40b7cfb8b35a9a1b9c7f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e9b53fa68641f45baabf40b7cfb8b35a9a1b9c7f.hip new file mode 100644 index 000000000000..b291dd76b356 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_e9b53fa68641f45baabf40b7cfb8b35a9a1b9c7f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; 
+ +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ea077e68dbc1bed2dd20a5f4dd35e0cad6330ee4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ea077e68dbc1bed2dd20a5f4dd35e0cad6330ee4.hip new file mode 100644 index 000000000000..d74ad5dafa42 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ea077e68dbc1bed2dd20a5f4dd35e0cad6330ee4.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ea591185b1c5f521023e250a26f742984255b241.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ea591185b1c5f521023e250a26f742984255b241.hip new file mode 100644 index 000000000000..ce0ee1322174 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ea591185b1c5f521023e250a26f742984255b241.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ea62567e9ea16771d8445464c38f5a2931cb355a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ea62567e9ea16771d8445464c38f5a2931cb355a.hip new file mode 100644 index 000000000000..47c92f36f553 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ea62567e9ea16771d8445464c38f5a2931cb355a.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 32, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<32, + ck_tile::bf16_t, + false, + true, + false, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ea6a6d4cc262ea838dbb83ee747112f95fa297bc.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ea6a6d4cc262ea838dbb83ee747112f95fa297bc.hip new file mode 100644 index 000000000000..fab6d688bcc1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ea6a6d4cc262ea838dbb83ee747112f95fa297bc.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_eab6cdc59bf216f7045f0cf5f221bb91ec415cd2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_eab6cdc59bf216f7045f0cf5f221bb91ec415cd2.hip new file mode 100644 index 000000000000..5dc190b447d5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_eab6cdc59bf216f7045f0cf5f221bb91ec415cd2.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_eac353f963c52624cf79e82cc2b2c02eed94b677.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_eac353f963c52624cf79e82cc2b2c02eed94b677.hip new file mode 100644 index 000000000000..cea6f89c5d1a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_eac353f963c52624cf79e82cc2b2c02eed94b677.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_eac5952f46f4f2bf06257b00661774eeed48a323.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_eac5952f46f4f2bf06257b00661774eeed48a323.hip new file mode 100644 index 000000000000..f634a21a052b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_eac5952f46f4f2bf06257b00661774eeed48a323.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 256, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, false, true>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_eb278488b2cca114adca5e4614d86f92447f937a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_eb278488b2cca114adca5e4614d86f92447f937a.hip new file mode 100644 index 000000000000..9830d036ca85 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_eb278488b2cca114adca5e4614d86f92447f937a.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ebb241b947a0adfc8e50c5d71765c14af24593ae.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ebb241b947a0adfc8e50c5d71765c14af24593ae.hip new file mode 100644 index 000000000000..881357123687 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ebb241b947a0adfc8e50c5d71765c14af24593ae.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ebb9abf5b09e63cbe76390bb46ff7cbefb3141f0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ebb9abf5b09e63cbe76390bb46ff7cbefb3141f0.hip new file mode 100644 index 000000000000..e8543aa36405 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ebb9abf5b09e63cbe76390bb46ff7cbefb3141f0.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ec171210efd217c07d357fcf42e5372ad7e9abab.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ec171210efd217c07d357fcf42e5372ad7e9abab.hip new file mode 100644 index 000000000000..bfdd7adfdec1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ec171210efd217c07d357fcf42e5372ad7e9abab.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ec3deb1382003ac010d9bc1c59d1878d3ec7a727.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ec3deb1382003ac010d9bc1c59d1878d3ec7a727.hip new file mode 100644 index 000000000000..61f7d97e7079 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ec3deb1382003ac010d9bc1c59d1878d3ec7a727.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( 
+ s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ec51d24ab5f24e003ed6751ae8ae5b327892b15a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ec51d24ab5f24e003ed6751ae8ae5b327892b15a.hip new file mode 100644 index 000000000000..38dee8f172ec --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ec51d24ab5f24e003ed6751ae8ae5b327892b15a.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ec7ec8d547ee9713aa3b5b667f22cdcaa8f62b2d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ec7ec8d547ee9713aa3b5b667f22cdcaa8f62b2d.hip new file mode 100644 index 000000000000..9e79206b7bd1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ec7ec8d547ee9713aa3b5b667f22cdcaa8f62b2d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ec7fc24902b1ebd8f2bf8088b0ecf6de8be8362d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ec7fc24902b1ebd8f2bf8088b0ecf6de8be8362d.hip new file mode 100644 index 000000000000..8962a6918bcc --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ec7fc24902b1ebd8f2bf8088b0ecf6de8be8362d.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, 
kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ec9f63a538940e5ace02ae5b5ddc01f730adac4d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ec9f63a538940e5ace02ae5b5ddc01f730adac4d.hip new file mode 100644 index 000000000000..16e339d93565 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ec9f63a538940e5ace02ae5b5ddc01f730adac4d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if 
(defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_eca613eaa8471ad7da66d2f8f2b8e07f6e02b467.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_eca613eaa8471ad7da66d2f8f2b8e07f6e02b467.hip new file mode 100644 index 000000000000..0d558baa021a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_eca613eaa8471ad7da66d2f8f2b8e07f6e02b467.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using 
fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ecd7dec90b3c62bf3a30bd75d3c6869529a06b01.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ecd7dec90b3c62bf3a30bd75d3c6869529a06b01.hip new file mode 100644 index 000000000000..b347eae247da --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ecd7dec90b3c62bf3a30bd75d3c6869529a06b01.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 128, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128, + ck_tile::fp16_t, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ece60111633db08f765b3c7cd5cd768cbd030255.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ece60111633db08f765b3c7cd5cd768cbd030255.hip new file mode 100644 index 000000000000..cc28ddb979d9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ece60111633db08f765b3c7cd5cd768cbd030255.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ed37ba962e0288e2840eb0925d016b5a7e3b3164.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ed37ba962e0288e2840eb0925d016b5a7e3b3164.hip new file mode 100644 index 000000000000..c7af660b7323 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ed37ba962e0288e2840eb0925d016b5a7e3b3164.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ed6bdf67720e938d538a867548ac3579b8238169.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ed6bdf67720e938d538a867548ac3579b8238169.hip new file mode 100644 index 000000000000..7665e4bb3add --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ed6bdf67720e938d538a867548ac3579b8238169.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else 
+ return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ede81dbc4cb208ef6e684c76ba1eb451d37fe10c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ede81dbc4cb208ef6e684c76ba1eb451d37fe10c.hip new file mode 100644 index 000000000000..cc154d8bbcf7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ede81dbc4cb208ef6e684c76ba1eb451d37fe10c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ee1a43f2210a8d1e5623411c95c33424cee5e747.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ee1a43f2210a8d1e5623411c95c33424cee5e747.hip new file mode 100644 index 000000000000..9af65064c5af --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ee1a43f2210a8d1e5623411c95c33424cee5e747.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ee239db5a67c23a383590a651f0d8a0be43a13c7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ee239db5a67c23a383590a651f0d8a0be43a13c7.hip new file mode 100644 index 000000000000..f5c3ef50b155 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ee239db5a67c23a383590a651f0d8a0be43a13c7.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ee8e709eec7aef1fa681053c6d2969a5ff18c45c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ee8e709eec7aef1fa681053c6d2969a5ff18c45c.hip new file mode 100644 index 000000000000..70a4e712867f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ee8e709eec7aef1fa681053c6d2969a5ff18c45c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ee974931e65d6b16b7c868d462b95dcae20b7513.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ee974931e65d6b16b7c868d462b95dcae20b7513.hip new file mode 100644 index 000000000000..817937913404 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ee974931e65d6b16b7c868d462b95dcae20b7513.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_eeb0e96b759e18cf703cfab0cda1385726f6e0a1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_eeb0e96b759e18cf703cfab0cda1385726f6e0a1.hip new file mode 100644 index 000000000000..0d405aba7ece --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_eeb0e96b759e18cf703cfab0cda1385726f6e0a1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_eee408cf9456ff977aa7d12345e9b2f1e60639f1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_eee408cf9456ff977aa7d12345e9b2f1e60639f1.hip new file mode 100644 index 000000000000..9346ac1d32cb --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_eee408cf9456ff977aa7d12345e9b2f1e60639f1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ef2ebb4a86e7ed0001de9c5e607b66fe8877409f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ef2ebb4a86e7ed0001de9c5e607b66fe8877409f.hip new file mode 100644 index 000000000000..92fbcde9d8e2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ef2ebb4a86e7ed0001de9c5e607b66fe8877409f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ef40f0acf1885096efb840ec5600ec421c4db331.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ef40f0acf1885096efb840ec5600ec421c4db331.hip new file mode 100644 index 000000000000..1f7e407864d7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ef40f0acf1885096efb840ec5600ec421c4db331.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + false, false>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::fp16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, false, false, false, false>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ef5421703cbfa63a58ec02701e245d479a1fbfc1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ef5421703cbfa63a58ec02701e245d479a1fbfc1.hip new file mode 100644 index 000000000000..9bcf6ad9ff7c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ef5421703cbfa63a58ec02701e245d479a1fbfc1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ef7cc2aa1ffd38298b52764a93cd1271b4d92f8d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ef7cc2aa1ffd38298b52764a93cd1271b4d92f8d.hip new file mode 100644 index 000000000000..c202ed8b7a90 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ef7cc2aa1ffd38298b52764a93cd1271b4d92f8d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_efaa0cb33c71cb8ca7b83dd0e7a6c7b01f6b50a9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_efaa0cb33c71cb8ca7b83dd0e7a6c7b01f6b50a9.hip new file mode 100644 index 000000000000..b0dedc50e795 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_efaa0cb33c71cb8ca7b83dd0e7a6c7b01f6b50a9.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 256, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<256, ck_tile::bf16_t, false, true, false>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_efb9e7d9af47cdf79f15f674f8976c05f08b0ce8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_efb9e7d9af47cdf79f15f674f8976c05f08b0ce8.hip new 
file mode 100644 index 000000000000..edc00c4b1239 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_efb9e7d9af47cdf79f15f674f8976c05f08b0ce8.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = 
fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_efc6a7b25710f0626c3af534111b161e1459d2e1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_efc6a7b25710f0626c3af534111b161e1459d2e1.hip new file mode 100644 index 000000000000..2749b00dc64f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_efc6a7b25710f0626c3af534111b161e1459d2e1.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f01468c62c878295443981662e037ec5213cf7a3.hip 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f01468c62c878295443981662e037ec5213cf7a3.hip new file mode 100644 index 000000000000..13f81636bf35 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f01468c62c878295443981662e037ec5213cf7a3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + 
+template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f020134822739be6fa0bb3d98e9dec79f025324a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f020134822739be6fa0bb3d98e9dec79f025324a.hip new file mode 100644 index 000000000000..25854794b083 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f020134822739be6fa0bb3d98e9dec79f025324a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + 
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f0209426a8e6bfeef7d8ae7b16db791888142298.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f0209426a8e6bfeef7d8ae7b16db791888142298.hip new file mode 100644 index 000000000000..26df83891174 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f0209426a8e6bfeef7d8ae7b16db791888142298.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f028af9e5e3c25800dde938e991aaab4fc1d64aa.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f028af9e5e3c25800dde938e991aaab4fc1d64aa.hip new file mode 100644 index 000000000000..4d37fd26387c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f028af9e5e3c25800dde938e991aaab4fc1d64aa.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f053c9c32518b895daaa3521827f37af78836fb8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f053c9c32518b895daaa3521827f37af78836fb8.hip new file mode 100644 index 000000000000..03107c712a5a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f053c9c32518b895daaa3521827f37af78836fb8.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f069b38b26c30bc770f74c856e47eb498f5818e7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f069b38b26c30bc770f74c856e47eb498f5818e7.hip new file mode 100644 index 000000000000..c0e69ef42e7d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f069b38b26c30bc770f74c856e47eb498f5818e7.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f0cad48d9bc80d58705ea60eb2dda4baad68cedb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f0cad48d9bc80d58705ea60eb2dda4baad68cedb.hip new file mode 100644 index 000000000000..228cd503785f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f0cad48d9bc80d58705ea60eb2dda4baad68cedb.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f1246d1013d954a9316f4432c986d3be9459c548.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f1246d1013d954a9316f4432c986d3be9459c548.hip new file mode 100644 index 000000000000..bf43a6fbfb41 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f1246d1013d954a9316f4432c986d3be9459c548.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, 
grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f12f1f1b679cabab04218037ef370d2c7e1fe332.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f12f1f1b679cabab04218037ef370d2c7e1fe332.hip new file mode 100644 index 000000000000..9800f89bbfda --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f12f1f1b679cabab04218037ef370d2c7e1fe332.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = 
k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f15c41ddb04ec7f80235bb3db19198dd6b699713.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f15c41ddb04ec7f80235bb3db19198dd6b699713.hip new file mode 100644 index 000000000000..3be783819eab --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f15c41ddb04ec7f80235bb3db19198dd6b699713.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f18c74becc24a93427d9c0838784e9b6caad6e81.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f18c74becc24a93427d9c0838784e9b6caad6e81.hip new file mode 100644 index 000000000000..23f0e69bd4d0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f18c74becc24a93427d9c0838784e9b6caad6e81.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr 
ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f1ecc90ad7b86791a9e6f73a582aeff30f393804.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f1ecc90ad7b86791a9e6f73a582aeff30f393804.hip new file mode 100644 index 000000000000..8daa2a9efd96 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f1ecc90ad7b86791a9e6f73a582aeff30f393804.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename 
FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f21596e8c608a795ff971aea8e199db9e72b65d7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f21596e8c608a795ff971aea8e199db9e72b65d7.hip new file mode 100644 index 000000000000..f4e72e895379 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f21596e8c608a795ff971aea8e199db9e72b65d7.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
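Editor's note: the autogenerated .hip files above and below all follow the same pattern. Each translation unit pins one kernel configuration (data type, block/warp tile shape, pipeline variant, bias/mask/dropout flags, padding flags) in a trait alias such as trait_0 or dq_dk_dv_trait_0, then provides the explicit specialization of fmha_fwd_<...> or fmha_bwd_dq_dk_dv_<...> for exactly that traits type, with the actual launch compiled only for __gfx90a__/__gfx942__ and reduced to a no-op return elsewhere. The sketch below is an illustration of that "one trait type per generated file, explicit specialization per trait, runtime-to-compile-time dispatch" idiom only; it is not part of the patch, and every name in it (FwdTraits, run_fwd, dispatch_fwd) is a hypothetical stand-in for the generated fmha_fwd_traits_ / fmha_fwd_ machinery.

// Illustrative sketch only -- hypothetical names, not the generated CK code.
#include <cstdio>
#include <stdexcept>

// Compile-time key: one struct per generated configuration.
template <int HeadDim, bool IsFp16, bool Causal>
struct FwdTraits {};

// Generic declaration; each autogenerated translation unit would contribute
// one explicit specialization for the single configuration it owns.
template <typename Traits>
float run_fwd(float scale);

// Two example specializations standing in for two generated .hip files.
template <>
float run_fwd<FwdTraits<128, true, false>>(float scale)
{
    std::puts("hd128 fp16 non-causal kernel");
    return scale; // a real kernel would return the measured launch time
}

template <>
float run_fwd<FwdTraits<64, false, true>>(float scale)
{
    std::puts("hd64 bf16 causal kernel");
    return scale;
}

// Runtime-to-compile-time dispatch, analogous to how the fmha_fwd entry
// point selects among the generated trait specializations.
float dispatch_fwd(int head_dim, bool is_fp16, bool causal, float scale)
{
    if (head_dim == 128 && is_fp16 && !causal)
        return run_fwd<FwdTraits<128, true, false>>(scale);
    if (head_dim == 64 && !is_fp16 && causal)
        return run_fwd<FwdTraits<64, false, true>>(scale);
    throw std::runtime_error("no kernel instantiated for this configuration");
}

int main()
{
    // Picks the hd128/fp16/non-causal specialization at compile time.
    return dispatch_fwd(128, /*is_fp16=*/true, /*causal=*/false, 1.0f) > 0.0f ? 0 : 1;
}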
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f24bd5b92ce6bba640b8ec6b4e53fe35902c5572.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f24bd5b92ce6bba640b8ec6b4e53fe35902c5572.hip new file mode 100644 index 000000000000..b1bd3ff5c76d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f24bd5b92ce6bba640b8ec6b4e53fe35902c5572.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f24d42e820adc1a26a428d59df7ffdd7f8580176.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f24d42e820adc1a26a428d59df7ffdd7f8580176.hip new file mode 100644 index 000000000000..a51d4c5c23c6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f24d42e820adc1a26a428d59df7ffdd7f8580176.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f24f26e45d5cf567d29fbe375fbf8abdec39186f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f24f26e45d5cf567d29fbe375fbf8abdec39186f.hip new file mode 100644 index 000000000000..2f2c89a1188b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f24f26e45d5cf567d29fbe375fbf8abdec39186f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f25b87c435bc5d7d85d738f3fdf68947d79f5a77.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f25b87c435bc5d7d85d738f3fdf68947d79f5a77.hip new file mode 100644 index 000000000000..b1946d6c24a1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f25b87c435bc5d7d85d738f3fdf68947d79f5a77.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f280e1639680ac1e5830a21f921bfe2cf364ef42.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f280e1639680ac1e5830a21f921bfe2cf364ef42.hip new file mode 100644 index 000000000000..13993a552faa --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f280e1639680ac1e5830a21f921bfe2cf364ef42.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, 
kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f2da112b1e07c44fc8a7f19368da203f6935049c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f2da112b1e07c44fc8a7f19368da203f6935049c.hip new file mode 100644 index 000000000000..fb2720ff39c7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f2da112b1e07c44fc8a7f19368da203f6935049c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if 
(defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f30316cfe49323638f71ba688dd8ff9b2266b335.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f30316cfe49323638f71ba688dd8ff9b2266b335.hip new file mode 100644 index 000000000000..58ae3c39fcf2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f30316cfe49323638f71ba688dd8ff9b2266b335.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, 
ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f3193ea266f3718398bc5622f8bc7042c3527a42.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f3193ea266f3718398bc5622f8bc7042c3527a42.hip new file mode 100644 index 000000000000..6010154f029c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f3193ea266f3718398bc5622f8bc7042c3527a42.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f34fdb8294257d951dcc9c4fa7ecf1192568b91b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f34fdb8294257d951dcc9c4fa7ecf1192568b91b.hip new file mode 100644 index 000000000000..393ae7e3f124 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f34fdb8294257d951dcc9c4fa7ecf1192568b91b.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f36aaa63ed42a578b953ebd614318d44cf44e8a3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f36aaa63ed42a578b953ebd614318d44cf44e8a3.hip new file mode 100644 index 000000000000..9ac0f7143159 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f36aaa63ed42a578b953ebd614318d44cf44e8a3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f395bec57c3b2e6e169134dd8d20b287d7405134.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f395bec57c3b2e6e169134dd8d20b287d7405134.hip new file mode 100644 index 000000000000..2ab46ef081bd --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f395bec57c3b2e6e169134dd8d20b287d7405134.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f3bf7ef503bb026258b3ec3d82d3ef1443046964.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f3bf7ef503bb026258b3ec3d82d3ef1443046964.hip new file mode 100644 index 000000000000..39e85e7ddadc --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f3bf7ef503bb026258b3ec3d82d3ef1443046964.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; 
+ +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f3d0166931e4406873d8f552a5d5b61fde2391a3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f3d0166931e4406873d8f552a5d5b61fde2391a3.hip new file mode 100644 index 000000000000..76fe2c03186a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f3d0166931e4406873d8f552a5d5b61fde2391a3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f3fd08d56f8a9be1a8dd104cdb1ac58e283b5064.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f3fd08d56f8a9be1a8dd104cdb1ac58e283b5064.hip new file mode 100644 index 000000000000..c9a585a170c1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f3fd08d56f8a9be1a8dd104cdb1ac58e283b5064.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f3ff73f82aee3184849d04c2364eaa45c6d0de9c.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f3ff73f82aee3184849d04c2364eaa45c6d0de9c.hip new file mode 100644 index 000000000000..3addb4329e6e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f3ff73f82aee3184849d04c2364eaa45c6d0de9c.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; 
+ +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f42cf0e5fe479690883507028748b0cd3dc83cbb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f42cf0e5fe479690883507028748b0cd3dc83cbb.hip new file mode 100644 index 000000000000..bd0bdf1993bf --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f42cf0e5fe479690883507028748b0cd3dc83cbb.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f4658c32d562f9d60c5ca1262a2e0df2375063bb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f4658c32d562f9d60c5ca1262a2e0df2375063bb.hip new file mode 100644 index 000000000000..af4ca0cdd9d3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f4658c32d562f9d60c5ca1262a2e0df2375063bb.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 32, + true, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<32, + ck_tile::bf16_t, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f48f8b681a405bfeba5aadaef40f32367ec5cd2b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f48f8b681a405bfeba5aadaef40f32367ec5cd2b.hip new file mode 100644 index 000000000000..275608fa5112 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f48f8b681a405bfeba5aadaef40f32367ec5cd2b.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 128, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128, + ck_tile::fp16_t, + false, + false, + true, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f4900c0a5c0d03dc17d7a907ab40652d9920e756.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f4900c0a5c0d03dc17d7a907ab40652d9920e756.hip new file mode 100644 index 000000000000..0c63b898e7d1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f4900c0a5c0d03dc17d7a907ab40652d9920e756.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f4a6438394dd3427f29aa0bbe58ad1f797c3c38d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f4a6438394dd3427f29aa0bbe58ad1f797c3c38d.hip new file mode 100644 index 000000000000..26ce14570265 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f4a6438394dd3427f29aa0bbe58ad1f797c3c38d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f4b87f983a5e84582efa1663f84da76cf60b5f6f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f4b87f983a5e84582efa1663f84da76cf60b5f6f.hip new file mode 100644 index 000000000000..dace41abd3a9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f4b87f983a5e84582efa1663f84da76cf60b5f6f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f4c803838f5644ccc6f04f7c8a6233fed0b6639e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f4c803838f5644ccc6f04f7c8a6233fed0b6639e.hip new file mode 100644 index 000000000000..a784b433b881 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f4c803838f5644ccc6f04f7c8a6233fed0b6639e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); 
+} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f4df1cbfbaf67705820f125b474469ad7ebab0c0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f4df1cbfbaf67705820f125b474469ad7ebab0c0.hip new file mode 100644 index 000000000000..a024563da303 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f4df1cbfbaf67705820f125b474469ad7ebab0c0.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f50fa4ea674a590d0a817367ad9915a5fce20c51.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f50fa4ea674a590d0a817367ad9915a5fce20c51.hip new file mode 100644 index 000000000000..637d3ef0cc0f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f50fa4ea674a590d0a817367ad9915a5fce20c51.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f51f1a11f778d99a00aa5959a3e58a41fcbfb1e3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f51f1a11f778d99a00aa5959a3e58a41fcbfb1e3.hip new file mode 100644 index 000000000000..a7219ad0dbe6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f51f1a11f778d99a00aa5959a3e58a41fcbfb1e3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f525b59df454ccf53da6cb201e0aa8d09f52a2ad.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f525b59df454ccf53da6cb201e0aa8d09f52a2ad.hip new file mode 100644 index 000000000000..fc02d2c2a9f0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f525b59df454ccf53da6cb201e0aa8d09f52a2ad.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f57f84892e2a8496169b7406e63b0d4f5aa63aaf.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f57f84892e2a8496169b7406e63b0d4f5aa63aaf.hip new file mode 100644 index 000000000000..d9d8e7ec6c5a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f57f84892e2a8496169b7406e63b0d4f5aa63aaf.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + 
return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f5803aadd93e33567aa6b23100ce4fbb6c040dd6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f5803aadd93e33567aa6b23100ce4fbb6c040dd6.hip new file mode 100644 index 000000000000..5a14831ce687 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f5803aadd93e33567aa6b23100ce4fbb6c040dd6.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f5f1797f6b672a55476348571ce17645c8a62869.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f5f1797f6b672a55476348571ce17645c8a62869.hip new file mode 100644 index 000000000000..6a4c999990e0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f5f1797f6b672a55476348571ce17645c8a62869.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f6566441ac3074578cfe45758ba0583c0da0a5ab.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f6566441ac3074578cfe45758ba0583c0da0a5ab.hip new file mode 100644 index 000000000000..7c731eb83cd1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f6566441ac3074578cfe45758ba0583c0da0a5ab.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
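Note: every autogenerated .hip file in this patch follows the same pattern. It pins one backward-pass configuration (data type, block/warp tiles, pipeline variant, bias/dropout/padding flags) into a dq_dk_dv_trait_0 alias and then provides the fmha_bwd_dq_dk_dv_, fmha_bwd_dq_dk_dv_oneshot_, and fmha_bwd_dq_dk_dv_get_name_ specializations for that trait, with the actual launch compiled only for gfx90a/gfx942. The sketch below is illustrative only and is not part of the patch; it assumes the ck_tile FMHA headers included by these generated files are on the include path, and it shows how one such specialization would be invoked on a HIP stream, mirroring the stream_config usage visible in the hunks above.

// Illustrative usage sketch (not part of the generated sources). Assumes the
// ck_tile FMHA headers used by the files in this patch are available; the
// template argument of the specialization is the trait alias defined in the
// corresponding generated file.
#include <hip/hip_runtime.h>

float run_generated_bwd_kernel(fmha_bwd_args args, hipStream_t stream)
{
    // The generated specializations take a stream_config wrapping the HIP
    // stream; the oneshot_ variant above constructs it the same way.
    ck_tile::stream_config cfg{stream};

    // Dispatch the dq/dk/dv backward kernel generated for this configuration.
    // It returns the value produced by ck_tile::launch_kernel, or 0.0 when
    // the translation unit was built for an unsupported gfx target.
    return fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(cfg, args);
}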
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f672bf80a78885428b2c02e522426470653a7351.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f672bf80a78885428b2c02e522426470653a7351.hip new file mode 100644 index 000000000000..f217112dc336 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f672bf80a78885428b2c02e522426470653a7351.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f682399cd6412fed6a1141296a7e4d42078f7b29.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f682399cd6412fed6a1141296a7e4d42078f7b29.hip new file mode 100644 index 000000000000..199d3a13f17c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f682399cd6412fed6a1141296a7e4d42078f7b29.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f6856ca950bcf173571766c3f04de4163be0402e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f6856ca950bcf173571766c3f04de4163be0402e.hip new file mode 100644 index 000000000000..2e1c1aa05a5e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f6856ca950bcf173571766c3f04de4163be0402e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f69548d6cced86c21c09c6475237a0cb926df0ed.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f69548d6cced86c21c09c6475237a0cb926df0ed.hip new file mode 100644 index 000000000000..3d3f5a67f990 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f69548d6cced86c21c09c6475237a0cb926df0ed.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f69878f4ca8cfe6b8d8748766f66a1ef8eab20ad.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f69878f4ca8cfe6b8d8748766f66a1ef8eab20ad.hip new file mode 100644 index 000000000000..60737953fb14 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f69878f4ca8cfe6b8d8748766f66a1ef8eab20ad.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f6f102a388ffb05c690a20a29cfe0b35a35eed61.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f6f102a388ffb05c690a20a29cfe0b35a35eed61.hip new file mode 100644 index 000000000000..42ee427842eb --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f6f102a388ffb05c690a20a29cfe0b35a35eed61.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f7035f4bfd8f2f427720a07e3c311bccc1dba683.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f7035f4bfd8f2f427720a07e3c311bccc1dba683.hip new file mode 100644 index 000000000000..14cf83ee896a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f7035f4bfd8f2f427720a07e3c311bccc1dba683.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f71f96ce4dcc7f789a8ace73c230c203b05ff6dc.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f71f96ce4dcc7f789a8ace73c230c203b05ff6dc.hip new file mode 100644 index 000000000000..002dbccee2e8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f71f96ce4dcc7f789a8ace73c230c203b05ff6dc.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f727911254904ce4341e4ff5f8bafc430b8cfbbf.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f727911254904ce4341e4ff5f8bafc430b8cfbbf.hip new file mode 100644 index 000000000000..8e09dab39633 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f727911254904ce4341e4ff5f8bafc430b8cfbbf.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f731289837f915e2aec1bd01eef1b3c1b099864d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f731289837f915e2aec1bd01eef1b3c1b099864d.hip new file mode 100644 index 000000000000..65acf8e88035 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f731289837f915e2aec1bd01eef1b3c1b099864d.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 32, + false, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<32, + ck_tile::bf16_t, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f79def2b4edf6d18f6ef1d6b141f9e0435441f6a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f79def2b4edf6d18f6ef1d6b141f9e0435441f6a.hip new file mode 100644 index 000000000000..65836ac6f42d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f79def2b4edf6d18f6ef1d6b141f9e0435441f6a.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, false,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f7aa9c39b06e55bf4bc9f9a2a0fb075c9d4e69ce.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f7aa9c39b06e55bf4bc9f9a2a0fb075c9d4e69ce.hip new file mode 100644 index 000000000000..40ba9885c6d0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f7aa9c39b06e55bf4bc9f9a2a0fb075c9d4e69ce.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f7cf08242b3fb1c643d4149bec985b667b9d28fa.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f7cf08242b3fb1c643d4149bec985b667b9d28fa.hip new file mode 100644 index 000000000000..c0d7dcf105d3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f7cf08242b3fb1c643d4149bec985b667b9d28fa.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f851da732f397624717160f89271514bc334b59b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f851da732f397624717160f89271514bc334b59b.hip new file mode 100644 index 000000000000..616877eba637 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f851da732f397624717160f89271514bc334b59b.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, 
grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f861d8693f82d22e2c5b1abbcbae5f30f4433e5e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f861d8693f82d22e2c5b1abbcbae5f30f4433e5e.hip new file mode 100644 index 000000000000..98a398a2dadb --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f861d8693f82d22e2c5b1abbcbae5f30f4433e5e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if 
(defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f87790f260630f312b84888dcbdf849ce130ae59.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f87790f260630f312b84888dcbdf849ce130ae59.hip new file mode 100644 index 000000000000..056ee5b30a00 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f87790f260630f312b84888dcbdf849ce130ae59.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using 
fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f87991cb7787a29d3ce4711b4ce04c5fb6a14ca9.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f87991cb7787a29d3ce4711b4ce04c5fb6a14ca9.hip new file mode 100644 index 000000000000..39324fd6e090 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f87991cb7787a29d3ce4711b4ce04c5fb6a14ca9.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f90410c26d7649e21e2ae5e32e7af89d84d2ea70.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f90410c26d7649e21e2ae5e32e7af89d84d2ea70.hip new file mode 100644 index 000000000000..9b832ba6b1a0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f90410c26d7649e21e2ae5e32e7af89d84d2ea70.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f92e9a82c879051d6fe3c42108f8a574187704af.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f92e9a82c879051d6fe3c42108f8a574187704af.hip new file mode 100644 index 000000000000..6f02a47ef1d3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f92e9a82c879051d6fe3c42108f8a574187704af.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f93bc23b8a4f1e0fc5c5756c4e1c835bf59dea09.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f93bc23b8a4f1e0fc5c5756c4e1c835bf59dea09.hip new file mode 100644 index 000000000000..912d254bd415 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f93bc23b8a4f1e0fc5c5756c4e1c835bf59dea09.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f93bf815b520a9d9e17b43bf9d7fb870751b6225.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f93bf815b520a9d9e17b43bf9d7fb870751b6225.hip new file mode 100644 index 000000000000..631f6d6710ca --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f93bf815b520a9d9e17b43bf9d7fb870751b6225.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f974b12e83e214c30995a25631d37df1478927af.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f974b12e83e214c30995a25631d37df1478927af.hip new file mode 100644 index 000000000000..cd39fec637f5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f974b12e83e214c30995a25631d37df1478927af.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 32, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<32, ck_tile::fp16_t, false, true, true>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f9824fb32933b27501ae8a7f43f460a2dda6a814.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f9824fb32933b27501ae8a7f43f460a2dda6a814.hip new file mode 100644 index 000000000000..47b1b0016c43 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f9824fb32933b27501ae8a7f43f460a2dda6a814.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f98a6b193fec3203eaa75819f6b51aa45a48f212.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f98a6b193fec3203eaa75819f6b51aa45a48f212.hip new file mode 100644 index 000000000000..ef354a336f02 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f98a6b193fec3203eaa75819f6b51aa45a48f212.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f9c58761c927b222112cb5cb6c9acb5d3c915785.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f9c58761c927b222112cb5cb6c9acb5d3c915785.hip new file mode 100644 index 000000000000..164fb143b34f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_f9c58761c927b222112cb5cb6c9acb5d3c915785.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fa16fa84278b489af253b52839786f94aeeac36f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fa16fa84278b489af253b52839786f94aeeac36f.hip new file mode 100644 index 000000000000..4576735604ab --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fa16fa84278b489af253b52839786f94aeeac36f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fa62a97675719c2e8e9bb97361b92ff1c7b9d2ef.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fa62a97675719c2e8e9bb97361b92ff1c7b9d2ef.hip new file mode 100644 index 000000000000..96f3ca0a42a4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fa62a97675719c2e8e9bb97361b92ff1c7b9d2ef.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
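+// The trait alias that follows packs this file's configuration (head dim 32, bf16,
+// KRKTRVR pipeline, no bias, the mask/dropout types above, plus a set of boolean
+// feature flags) into a single key type; the explicit specializations at the bottom
+// of the file are registered under it so a caller can select exactly this variant.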
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fa85f869a92f0482605e52019828244b12e12b44.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fa85f869a92f0482605e52019828244b12e12b44.hip new file mode 100644 index 000000000000..d4ca3b055d0b --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fa85f869a92f0482605e52019828244b12e12b44.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fabdc143c29d5ca50ab1e96a814bda6d05b0d5d2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fabdc143c29d5ca50ab1e96a814bda6d05b0d5d2.hip new file mode 100644 index 000000000000..d0124f68af10 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fabdc143c29d5ca50ab1e96a814bda6d05b0d5d2.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fac5a0f98b94530befd634891e42c424bb86f0e1.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fac5a0f98b94530befd634891e42c424bb86f0e1.hip new file mode 100644 index 000000000000..46659f7fb59f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fac5a0f98b94530befd634891e42c424bb86f0e1.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fac99c3c82b77946f6844699d2333cd532a78a26.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fac99c3c82b77946f6844699d2333cd532a78a26.hip new file mode 100644 index 000000000000..5f26cbb6af9d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fac99c3c82b77946f6844699d2333cd532a78a26.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, true,128, 128, 32, 128, 32, 128, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_faf56e45b2240515e97fc1bfd552eb03b6de5094.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_faf56e45b2240515e97fc1bfd552eb03b6de5094.hip new file mode 100644 index 000000000000..ff068ebf08d7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_faf56e45b2240515e97fc1bfd552eb03b6de5094.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
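Each of these generated backward files follows the same template: it fixes one tile/trait configuration, builds a kernel type from it, and then explicitly specializes fmha_bwd_dq_dk_dv_, fmha_bwd_dq_dk_dv_oneshot_, and fmha_bwd_dq_dk_dv_get_name_ for that trait, with the actual launch compiled only for gfx90a/gfx942. As a rough illustration of how host code could reach one of these specializations, the sketch below assumes the generic template declarations, fmha_bwd_args, and the trait aliases are made visible through fmha_bwd.hpp; the function name run_one_bwd_variant and the exact call syntax are illustrative only, not part of this patch.

// Minimal, hypothetical host-side call into one generated specialization.
// fmha_bwd.hpp is assumed to declare the generic fmha_bwd_dq_dk_dv_* templates,
// fmha_bwd_args, and the trait alias used below.
#include <iostream>
#include <string>
#include "fmha_bwd.hpp"

float run_one_bwd_variant(const ck_tile::stream_config& stream, fmha_bwd_args args)
{
    // dq_dk_dv_trait_0 stands for whichever trait the chosen autogenerated file
    // defines; the specializations called here were compiled from that file.
    if (stream.log_level_ > 0)
        std::cout << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>() << std::endl;
    // Forwards whatever ck_tile::launch_kernel reports, or 0.0 when the build
    // targets neither gfx90a nor gfx942 (see the #if guard in each generated file).
    return fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(stream, args);
}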
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_faf686067fa433cea5e95dd523846dc881eff635.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_faf686067fa433cea5e95dd523846dc881eff635.hip new file mode 100644 index 000000000000..aa184595a304 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_faf686067fa433cea5e95dd523846dc881eff635.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fb2fbb135d59028afcf867c2cf08edc323565528.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fb2fbb135d59028afcf867c2cf08edc323565528.hip new file mode 100644 index 000000000000..668aa06ab5e4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fb2fbb135d59028afcf867c2cf08edc323565528.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
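+// The trait below keys the fp16, head-dim-128, KRKTRVR-pipeline, ALIBI-bias variant
+// of the same backward kernel; everything after it is the shared launch boilerplate.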
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fb4c15452f9155c5966990f09432e5eb7e28e785.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fb4c15452f9155c5966990f09432e5eb7e28e785.hip new file mode 100644 index 000000000000..8e97e9ed3aea --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fb4c15452f9155c5966990f09432e5eb7e28e785.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fb4c5f8fecfbbe16e6648becb3b5ca89fa3d8a94.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fb4c5f8fecfbbe16e6648becb3b5ca89fa3d8a94.hip new file mode 100644 index 000000000000..e80141ab174f --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fb4c5f8fecfbbe16e6648becb3b5ca89fa3d8a94.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fb5bb49928ce5515d7b297d5eadd4ec70a22d60b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fb5bb49928ce5515d7b297d5eadd4ec70a22d60b.hip new file mode 100644 index 000000000000..8e0f6d9cd464 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fb5bb49928ce5515d7b297d5eadd4ec70a22d60b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fb79e1f9231692d736dbada062ed6821f34927bf.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fb79e1f9231692d736dbada062ed6821f34927bf.hip new file mode 100644 index 000000000000..48413b3f03c8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fb79e1f9231692d736dbada062ed6821f34927bf.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
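
// Editorial aside (not part of the patch): every autogenerated file above and below
// follows the same shape — it defines one compile-time trait alias (dq_dk_dv_trait_0
// or trait_0) and then provides the explicit specialization of a dispatch template
// (fmha_bwd_dq_dk_dv_<...>, fmha_fwd_<...>) keyed on that trait, with the kernel launch
// guarded for gfx90a/gfx942. The sketch below is a minimal, self-contained illustration
// of that trait-keyed specialization pattern; KernelTraits, LaunchArgs and run_bwd_kernel
// are hypothetical stand-ins for the ck_tile/fmha_bwd types, not the real API.

#include <iostream>
#include <string>

// Hypothetical compile-time trait bundle mirroring the role of fmha_bwd_dq_dk_dv_traits_<...>.
template <int HeadDim, typename DataType, bool IsGroupMode>
struct KernelTraits {
    static std::string name() {
        return "fmha_bwd_hd" + std::to_string(HeadDim) +
               (IsGroupMode ? "_group" : "_batch");
    }
};

struct LaunchArgs { int batch; int seqlen; };  // stand-in for fmha_bwd_args

// Primary template is declared only, so every supported configuration must be
// supplied by an explicit specialization in one of the generated files.
template <typename Traits>
float run_bwd_kernel(const LaunchArgs& a);

// One generated file contributes exactly one specialization; the real files
// additionally compile-time guard the launch with __gfx90a__/__gfx942__.
using trait_hd128_batch = KernelTraits<128, float, false>;

template <>
float run_bwd_kernel<trait_hd128_batch>(const LaunchArgs& a) {
    std::cout << trait_hd128_batch::name() << ": batch=" << a.batch
              << " seqlen=" << a.seqlen << '\n';
    return 0.0f;  // the real specializations return the measured kernel time
}

int main() {
    LaunchArgs args{2, 1024};
    return run_bwd_kernel<trait_hd128_batch>(args) == 0.0f ? 0 : 1;
}

// Keeping one specialization per translation unit is what lets the build compile the
// many head-dim/dtype/bias/dropout combinations in parallel and link only the variants
// that were generated.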
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fb9477a613665cebcad781389ba7c5a36f51efe2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fb9477a613665cebcad781389ba7c5a36f51efe2.hip new file mode 100644 index 000000000000..54776cf48f8c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fb9477a613665cebcad781389ba7c5a36f51efe2.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fba36678d5047ded97ee7a7ba9feb9569afdb6ea.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fba36678d5047ded97ee7a7ba9feb9569afdb6ea.hip new file mode 100644 index 000000000000..630108f41c52 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fba36678d5047ded97ee7a7ba9feb9569afdb6ea.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using 
dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fba47fa8d9b5375bc408af68b67345ab9dba2eb8.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fba47fa8d9b5375bc408af68b67345ab9dba2eb8.hip new file mode 100644 index 000000000000..ac6c30797327 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fba47fa8d9b5375bc408af68b67345ab9dba2eb8.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, false, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fbea85b766bf0c918ee0baf24dffc6a5563d5105.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fbea85b766bf0c918ee0baf24dffc6a5563d5105.hip new file mode 100644 index 000000000000..3ba3f9fe627c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fbea85b766bf0c918ee0baf24dffc6a5563d5105.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fbeec221cd63adaedceec39db41ea942f99f5133.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fbeec221cd63adaedceec39db41ea942f99f5133.hip new file mode 100644 index 000000000000..c18420b352f0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fbeec221cd63adaedceec39db41ea942f99f5133.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fc030b61ae20c4b7d9b2d10930a17e01e9e93328.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fc030b61ae20c4b7d9b2d10930a17e01e9e93328.hip new file mode 100644 index 000000000000..4a3c485deb24 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fc030b61ae20c4b7d9b2d10930a17e01e9e93328.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; 
+ +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fc1790325b59bd44b0a5f6cf9723a25fd845cba7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fc1790325b59bd44b0a5f6cf9723a25fd845cba7.hip new file mode 100644 index 000000000000..815bcf0af226 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fc1790325b59bd44b0a5f6cf9723a25fd845cba7.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fc1eb85a00017efdc610e4259d2abe935b85304f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fc1eb85a00017efdc610e4259d2abe935b85304f.hip new file mode 100644 index 000000000000..64139de1e3a3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fc1eb85a00017efdc610e4259d2abe935b85304f.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fc5841a729099340d608e31023acbeaeade3e886.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fc5841a729099340d608e31023acbeaeade3e886.hip new file mode 100644 index 000000000000..b0b857d26e6d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fc5841a729099340d608e31023acbeaeade3e886.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fc5ebf0f2200f37ccc0849e0c3745f6e2f00111d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fc5ebf0f2200f37ccc0849e0c3745f6e2f00111d.hip new file mode 100644 index 000000000000..43b274eecde0 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fc5ebf0f2200f37ccc0849e0c3745f6e2f00111d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fc7b0916744b593435d8e1e7b6d874d760cd5e3b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fc7b0916744b593435d8e1e7b6d874d760cd5e3b.hip new file mode 100644 index 000000000000..a6820e08447d --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fc7b0916744b593435d8e1e7b6d874d760cd5e3b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fc86c13e933cba40553ffba31d53aad27415ce4b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fc86c13e933cba40553ffba31d53aad27415ce4b.hip new file mode 100644 index 000000000000..7b8d8bdeb398 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fc86c13e933cba40553ffba31d53aad27415ce4b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fcb0b08e29b2e1bf181fceceb9dc416e54f52b00.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fcb0b08e29b2e1bf181fceceb9dc416e54f52b00.hip new file mode 100644 index 000000000000..fd63719fa021 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fcb0b08e29b2e1bf181fceceb9dc416e54f52b00.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fcb6ef39c3db49f26f736d6c9221dd825409ec4e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fcb6ef39c3db49f26f736d6c9221dd825409ec4e.hip new file mode 100644 index 000000000000..444b07067e61 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fcb6ef39c3db49f26f736d6c9221dd825409ec4e.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, true,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fcbe827108d252b2f5847fa8e132c9c3e56a90a0.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fcbe827108d252b2f5847fa8e132c9c3e56a90a0.hip new file mode 100644 index 000000000000..cfa4f3dab02c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fcbe827108d252b2f5847fa8e132c9c3e56a90a0.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + 
return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fccabea88b8e290688c1b360875d228e6fdf1624.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fccabea88b8e290688c1b360875d228e6fdf1624.hip new file mode 100644 index 000000000000..72bd02eae432 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fccabea88b8e290688c1b360875d228e6fdf1624.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || 
defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fd10a3b937e9659716925e39a01d794914b08e26.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fd10a3b937e9659716925e39a01d794914b08e26.hip new file mode 100644 index 000000000000..b2b4fc033c4a --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fd10a3b937e9659716925e39a01d794914b08e26.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<2, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + true, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, true,128, 64, 16, 32, 32, 32, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, 
grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fd19d7614f2ed5da21a52ed172ef62cc07c9c01a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fd19d7614f2ed5da21a52ed172ef62cc07c9c01a.hip new file mode 100644 index 000000000000..f01e1872b5ae --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fd19d7614f2ed5da21a52ed172ef62cc07c9c01a.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 64, + true, + false, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<64, + ck_tile::bf16_t, + true, + true, + false, + false>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fd26e43ca652e6f58ff48c356165aa4349833b55.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fd26e43ca652e6f58ff48c356165aa4349833b55.hip new file mode 100644 index 000000000000..08035d2ca663 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fd26e43ca652e6f58ff48c356165aa4349833b55.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fd345632e0cae0d549ba79626a08b1885711deb6.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fd345632e0cae0d549ba79626a08b1885711deb6.hip new file mode 100644 index 000000000000..64baf279f427 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fd345632e0cae0d549ba79626a08b1885711deb6.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fd3558b4c7a667dbc365c4c2ceda646975408f51.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fd3558b4c7a667dbc365c4c2ceda646975408f51.hip new file mode 100644 index 000000000000..0e1cf08e17a1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fd3558b4c7a667dbc365c4c2ceda646975408f51.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. 
+// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + 
ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fd614df484b263deae3b3c20adb0ce7b62eaa651.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fd614df484b263deae3b3c20adb0ce7b62eaa651.hip new file mode 100644 index 000000000000..877c173a42e7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fd614df484b263deae3b3c20adb0ce7b62eaa651.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = 
fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fd9cd1305633b62b68fb8474ce021f639f8492e7.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fd9cd1305633b62b68fb8474ce021f639f8492e7.hip new file mode 100644 index 000000000000..3001308cb7f4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fd9cd1305633b62b68fb8474ce021f639f8492e7.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = 
ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fde12cd366d6850ce26afce98e5076b695b4875b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fde12cd366d6850ce26afce98e5076b695b4875b.hip new file mode 100644 index 000000000000..1896571bff36 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fde12cd366d6850ce26afce98e5076b695b4875b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fe245e9ea974adce2b9807d33b9ba12d916eaffb.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fe245e9ea974adce2b9807d33b9ba12d916eaffb.hip new file mode 100644 index 000000000000..4bad8e1b5a5e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fe245e9ea974adce2b9807d33b9ba12d916eaffb.hip @@ -0,0 +1,84 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>; +using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>; + +using fmha_shape_0 = ck_tile::TileFmhaShape, + fmha_warp_tile_0, + ck_tile::sequence<4, 1, 1>, + fmha_warp_tile_0, + true>; + +using fmha_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; + +using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_0, + false, + fmha_mask_0, + fmha_trait_0>; + +using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync< + fmha_pipeline_problem_0>; + +using fmha_epilogue_0 = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + true, true>>; + +using fmha_kernel_0 = + ck_tile::FmhaFwdKernel, + fmha_pipeline_0, + fmha_epilogue_0>; + +using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true, + ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{ + using k_ = fmha_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fe72cdd69944d2d765478d4aed13066a02b76f6d.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fe72cdd69944d2d765478d4aed13066a02b76f6d.hip new file mode 100644 index 000000000000..e6856b66f61c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fe72cdd69944d2d765478d4aed13066a02b76f6d.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return 
k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fe8b8c3525fe86a20a2d6c69585f3e36c16caabd.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fe8b8c3525fe86a20a2d6c69585f3e36c16caabd.hip new file mode 100644 index 000000000000..6957e7bff394 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fe8b8c3525fe86a20a2d6c69585f3e36c16caabd.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + true, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + true, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fe97b7adcd67ed9bda8831d1f3f1ca7590c6d251.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fe97b7adcd67ed9bda8831d1f3f1ca7590c6d251.hip new file mode 100644 index 000000000000..bc042bf645c1 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fe97b7adcd67ed9bda8831d1f3f1ca7590c6d251.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + true, + false, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fe9d98dbec5096a89b116f85675af772f023014a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fe9d98dbec5096a89b116f85675af772f023014a.hip new file mode 100644 index 000000000000..bbc4cd7e95fa --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fe9d98dbec5096a89b116f85675af772f023014a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} 
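[Editor's note on the recurring pattern in these generated files] Each autogenerated translation unit above and below repeats the same structure: a single backward kernel type is instantiated from a tile/trait configuration, and three entry points (a timed launcher, a one-shot launcher, and a name getter) are explicitly specialized on a trait type so that dispatch code elsewhere in the CK backend can resolve to this one object file; the launch body is compiled only for gfx90a/gfx942 and degenerates to a stub on other architectures. The sketch below is a minimal illustration of that trait-keyed specialization plus architecture guard. It is not the CK API: the names my_trait, my_kernel, and run_ are hypothetical stand-ins for dq_dk_dv_trait_0, fmha_bwd_dq_dk_dv_kernel_0, and fmha_bwd_dq_dk_dv_.

    // Illustrative sketch only; hypothetical names, not the actual CK symbols.
    #include <iostream>
    #include <string>

    struct my_trait {};                 // stands in for dq_dk_dv_trait_0
    struct my_kernel {                  // stands in for fmha_bwd_dq_dk_dv_kernel_0
        static std::string GetName() { return "my_kernel"; }
    };

    // Generic declaration normally lives in a shared header; each generated
    // .hip file provides exactly one explicit specialization keyed by its trait.
    template <typename Trait>
    float run_(float elapsed);

    template <>
    float run_<my_trait>(float elapsed)
    {
    #if (defined(__gfx90a__) || defined(__gfx942__))
        // The real files build kargs/grids and call ck_tile::launch_kernel here.
        std::cout << my_kernel::GetName() << "\n";
        return elapsed;
    #else
        (void)elapsed;
        return 0.0f;                    // unsupported arch: stub keeps the link intact
    #endif
    }

    int main() { std::cout << run_<my_trait>(1.0f) << "\n"; }

Because the specialization is selected purely by the trait type, adding a new tile configuration only requires generating another such file; no central switch statement has to change.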
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_feb5e77111fe1e20bafdb83a925b5faeeb6214af.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_feb5e77111fe1e20bafdb83a925b5faeeb6214af.hip new file mode 100644 index 000000000000..619437013f9c --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_feb5e77111fe1e20bafdb83a925b5faeeb6214af.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fecd7501265b4c4dcf015485e63e2324304f70d3.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fecd7501265b4c4dcf015485e63e2324304f70d3.hip new file mode 100644 index 000000000000..521da69e5203 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fecd7501265b4c4dcf015485e63e2324304f70d3.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + 
ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fecffa403b3631b1957e1a9a06f18fdb3b4eee5f.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fecffa403b3631b1957e1a9a06f18fdb3b4eee5f.hip new file mode 100644 index 000000000000..aa3b11549de4 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fecffa403b3631b1957e1a9a06f18fdb3b4eee5f.hip @@ -0,0 +1,71 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_bwd_dot_do_o_trait_0 = + ck_tile::TileFmhaBwdOGradDotOTraits; + +using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + 32, + false, + fmha_bwd_dot_do_o_trait_0>; + +using fmha_bwd_dot_do_o_0 = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_0 = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_0 = + fmha_bwd_dot_do_o_traits_<32, ck_tile::bf16_t, false, true, true>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{ + using k_ = fmha_bwd_dot_do_o_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ff453e3bdc9752cb7b81f7cc3056325a8b9a8ad4.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ff453e3bdc9752cb7b81f7cc3056325a8b9a8ad4.hip new file mode 100644 index 000000000000..cf7620a56dc5 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ff453e3bdc9752cb7b81f7cc3056325a8b9a8ad4.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ff6862dbdbb20bc63a650e1f93e9ac169bb702b2.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ff6862dbdbb20bc63a650e1f93e9ac169bb702b2.hip new file mode 100644 index 000000000000..26ea0d1d0354 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ff6862dbdbb20bc63a650e1f93e9ac169bb702b2.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ffb5b7349a671b182d73c8016590f26fe06a4cba.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ffb5b7349a671b182d73c8016590f26fe06a4cba.hip new file mode 100644 index 000000000000..6e646b9ee3d8 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ffb5b7349a671b182d73c8016590f26fe06a4cba.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + true, + false, + true, + true, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ffb8adef0cef91a86f36872407fea35df90e8f2b.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ffb8adef0cef91a86f36872407fea35df90e8f2b.hip new file mode 100644 index 000000000000..2b475a9ecdba --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ffb8adef0cef91a86f36872407fea35df90e8f2b.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::bf16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::bf16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ffc6056d9fe125a4dbe08c1d86354e51f7daadd5.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ffc6056d9fe125a4dbe08c1d86354e51f7daadd5.hip new file mode 100644 index 000000000000..3f713f229af9 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ffc6056d9fe125a4dbe08c1d86354e51f7daadd5.hip @@ -0,0 +1,79 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_bwd_convert_dq_trait_0 = + ck_tile::TileFmhaBwdConvertQGradTraits; + +using fmha_bwd_convert_dq_pipeline_problem_0 = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 64, + false, + true, + fmha_bwd_convert_dq_trait_0>; + +using fmha_bwd_convert_dq_0 = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_0 = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<64, + ck_tile::fp16_t, + false, + true, + false, + true>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{ + using k_ = fmha_bwd_convert_dq_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ffd868d49abdb769ab82c21508d655daf54b8a99.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ffd868d49abdb769ab82c21508d655daf54b8a99.hip new file mode 100644 index 000000000000..6c055ce460d2 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_ffd868d49abdb769ab82c21508d655daf54b8a99.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + true>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + true>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + true, + true, + true, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff 
--git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fff7aa57cca501f221077124359a589b3a6f9d0a.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fff7aa57cca501f221077124359a589b3a6f9d0a.hip new file mode 100644 index 000000000000..29c2a526ee1e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fff7aa57cca501f221077124359a589b3a6f9d0a.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + true, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + false, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + false, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + false, + false, + false, + true>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return 
ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fffbfcac254e33926131a71905e93f9cc0aef89e.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fffbfcac254e33926131a71905e93f9cc0aef89e.hip new file mode 100644 index 000000000000..a54a4399b8bb --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_ck_autogen_fffbfcac254e33926131a71905e93f9cc0aef89e.hip @@ -0,0 +1,144 @@ +// ========================================== +// THIS CODE IS AUTOGENERATED. DO NOT MODIFY. +// @generated +// ========================================== +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +// auto generated by generate.py +#include + +using fmha_dtype_0 = ck_tile::fp16_t; + +using fmha_block_tile_0 = ck_tile:: + sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>; +using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; +using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; +using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>; +using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; +using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; +using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; + +using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_0, + false, + false, + fmha_mask_0, + fmha_dropout_0, + fmha_bwd_trait_0>; + +using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; + +using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::KGradDataType, + true, + false>>; + +using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig::VGradDataType, + true, + false>>; + +using fmha_bwd_dq_dk_dv_kernel_0 = + ck_tile::FmhaBwdDQDKDVKernel; + 
+using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32, + ck_tile::fp16_t, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::ALIBI, + false, + false, + true, + false, + false, + false>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); +#else + return 0.0; +#endif +} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; +#if (defined(__gfx90a__) || defined(__gfx942__)) + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); +#endif +} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{ + using k_ = fmha_bwd_dq_dk_dv_kernel_0; + return k_::GetName(); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_fwd.hpp b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_fwd.hpp new file mode 100644 index 000000000000..90eb5c208699 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_fwd.hpp @@ -0,0 +1,773 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include + +template +struct FmhaFwdTypeConfig; + +template <> +struct FmhaFwdTypeConfig +{ + using QDataType = ck_tile::half_t; + using KDataType = ck_tile::half_t; + using VDataType = ck_tile::half_t; + using BiasDataType = ck_tile::half_t; + using RandValOutputDataType = uint8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::half_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::half_t; +}; + +template <> +struct FmhaFwdTypeConfig +{ + using QDataType = ck_tile::bf16_t; + using KDataType = ck_tile::bf16_t; + using VDataType = ck_tile::bf16_t; + using BiasDataType = ck_tile::bf16_t; + using RandValOutputDataType = uint8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::bf16_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::bf16_t; +}; + +template <> +struct FmhaFwdTypeConfig +{ + using QDataType = ck_tile::fp8_t; + using KDataType = ck_tile::fp8_t; + using VDataType = ck_tile::fp8_t; + using BiasDataType = float; + using RandValOutputDataType = uint8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::fp8_t; +}; + +template <> +struct FmhaFwdTypeConfig +{ + using QDataType = ck_tile::bf8_t; + using KDataType = ck_tile::bf8_t; + using VDataType = ck_tile::bf8_t; + using BiasDataType = ck_tile::bf8_t; + using RandValOutputDataType = uint8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::bf8_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::bf8_t; +}; + +struct FmhaMasks +{ + using NoMask = ck_tile::GenericAttentionMask; + using GenericMask = ck_tile::GenericAttentionMask; + using CausalMask = ck_tile::GenericAttentionMask; +}; + +// runtime args, some will passed to karg, some will used to compute grids/blocks +struct fmha_fwd_args +{ + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + const void* bias_ptr; // bias or alibi_slope pointer + void* rand_val_ptr; + void* lse_ptr; + void* o_ptr; + + const void* seqstart_q_ptr; + const void* seqstart_k_ptr; + const void* + seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t batch; + ck_tile::index_t max_seqlen_q; + ck_tile::index_t hdim_q; + ck_tile::index_t 
hdim_v; + ck_tile::index_t nhead_q; + ck_tile::index_t nhead_k; + + float scale_s; + float scale_p; + float scale_o; + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0 + ck_tile::index_t stride_randval; + ck_tile::index_t stride_o; + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_bias; + ck_tile::index_t nhead_stride_randval; + ck_tile::index_t nhead_stride_lse; + ck_tile::index_t nhead_stride_o; + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_bias; + ck_tile::index_t batch_stride_randval; + ck_tile::index_t batch_stride_lse; + ck_tile::index_t batch_stride_o; + + ck_tile::index_t window_size_left; + ck_tile::index_t window_size_right; + ck_tile::index_t mask_type; + + float p_drop; + bool s_randval; + + std::variant, std::pair> + drop_seed_offset; +}; + +struct fmha_fwd_splitkv_args +{ + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + const void* bias_ptr; // bias or alibi_slope pointer + void* lse_acc_ptr; + void* o_acc_ptr; + void* lse_ptr; + void* o_ptr; + + void* block_table_ptr; + ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr + ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr + + const void* cache_batch_idx; + + // the real seqlen_q & seqlen_k are decided by following: + // batch mode: seqlen_q = kargs.seqlen_q + // seqlen_k = kargs.seqlen_k + // group mode: seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b] + // seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b] + // kvcache mode (use same kernel as batch mode): + // seqlen_q = kargs.seqlen_q + // seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b] + const void* seqstart_q_ptr; + const void* seqstart_k_ptr; + const void* seqlen_k_ptr; + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t batch; + ck_tile::index_t max_seqlen_q; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + ck_tile::index_t nhead_q; + ck_tile::index_t nhead_k; + ck_tile::index_t num_splits; + + float scale_s; + float scale_p; + float scale_o; + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0 + ck_tile::index_t stride_o_acc; + ck_tile::index_t stride_o; + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_bias; + ck_tile::index_t nhead_stride_lse; + ck_tile::index_t nhead_stride_lse_acc; + ck_tile::index_t nhead_stride_o_acc; + ck_tile::index_t nhead_stride_o; + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_bias; + ck_tile::index_t batch_stride_lse; + ck_tile::index_t batch_stride_lse_acc; + ck_tile::index_t batch_stride_o_acc; + ck_tile::index_t batch_stride_o; + ck_tile::index_t split_stride_lse_acc; + ck_tile::index_t split_stride_o_acc; + + ck_tile::index_t window_size_left; + ck_tile::index_t window_size_right; + ck_tile::index_t mask_type; +}; + +struct fmha_fwd_appendkv_args +{ + void* q_ptr; + void* k_ptr; + const void* knew_ptr; + void* v_ptr; + const void* vnew_ptr; + + const void* 
seqlen_k_ptr; + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_knew; + ck_tile::index_t batch; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + ck_tile::index_t nhead_q; + ck_tile::index_t nhead_k; + + const void* rotary_cos_ptr; // only used if 'rotary_dim' > 0 + const void* rotary_sin_ptr; // only used if 'rotary_dim' > 0 + ck_tile::index_t rotary_dim; + bool has_mask; + + void* block_table_ptr; + ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr + ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr + + const void* cache_batch_idx; + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_knew; + ck_tile::index_t stride_v; + ck_tile::index_t stride_vnew; + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_knew; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_vnew; + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_knew; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_vnew; +}; + +template +auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) +{ + assert(args.nhead_q % args.nhead_k == 0); + auto kargs = [&] { + // create group mode kernel arguments + if constexpr(FmhaKernel::kIsGroupMode) + { + return FmhaKernel::MakeKargsImpl(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.rand_val_ptr, + args.lse_ptr, + args.o_ptr, + args.seqstart_q_ptr, + args.seqstart_k_ptr, + args.seqlen_k_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale_s, + args.scale_p, + args.scale_o, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_lse, + args.nhead_stride_o, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.s_randval, + args.drop_seed_offset); + } + else + { // create batch mode kernel arguments + return FmhaKernel::MakeKargsImpl(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.rand_val_ptr, + args.lse_ptr, + args.o_ptr, + args.seqlen_q, + args.seqlen_k, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale_s, + args.scale_p, + args.scale_o, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_lse, + args.nhead_stride_o, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + args.batch_stride_bias, + args.batch_stride_randval, + args.batch_stride_lse, + args.batch_stride_o, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.s_randval, + args.drop_seed_offset); + } + }(); + + dim3 grids = FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v); + return ck_tile::make_tuple(kargs, grids); +} + +template +auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) +{ + assert(args.nhead_q % args.nhead_k == 0); + auto kargs = [&] { + // create group mode kernel arguments + if constexpr(Kernel::kIsGroupMode) + { + return Kernel::MakeKargsImpl(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.lse_acc_ptr, + 
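+// The args.nhead_q / args.nhead_k value passed above is the GQA/MQA group size: e.g. 32 query
+// heads over 8 KV heads gives a ratio of 4, so (with the usual contiguous grouping) query heads
+// 0..3 share KV head 0, 4..7 share KV head 1, and so on; the assert at the top of these helpers
+// enforces that this ratio is integral.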
args.o_acc_ptr, + args.batch, + args.seqstart_q_ptr, + args.seqstart_k_ptr, + args.seqlen_k_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.num_splits, + args.scale_s, + args.scale_p, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_o_acc, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_lse_acc, + args.nhead_stride_o_acc, + args.batch_stride_k, // only used for paged-kvcache + args.batch_stride_v, // only used for paged-kvcache + args.split_stride_lse_acc, + args.split_stride_o_acc, + args.window_size_left, + args.window_size_right, + args.mask_type); + } + else + { // create batch mode kernel arguments + return Kernel::MakeKargsImpl(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.lse_acc_ptr, + args.o_acc_ptr, + args.batch, + args.seqlen_q, + args.seqlen_k, + args.seqlen_k_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.num_splits, + args.block_table_ptr, + args.batch_stride_block_table, + args.page_block_size, + args.cache_batch_idx, + args.scale_s, + args.scale_p, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_o_acc, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_lse_acc, + args.nhead_stride_o_acc, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + args.batch_stride_bias, + args.batch_stride_lse_acc, + args.batch_stride_o_acc, + args.split_stride_lse_acc, + args.split_stride_o_acc, + args.window_size_left, + args.window_size_right, + args.mask_type); + } + }(); + + dim3 grids = + Kernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, args.num_splits); + + return ck_tile::make_tuple(kargs, grids); +} + +template +auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args) +{ + assert(args.nhead_q % args.nhead_k == 0); + auto kargs = [&] { + // create group mode kernel argumentszs + if constexpr(Kernel::kIsGroupMode) + { + return Kernel::MakeKargsImpl(args.lse_acc_ptr, + args.o_acc_ptr, + args.lse_ptr, + args.o_ptr, + args.batch, + args.seqstart_q_ptr, + args.hdim_v, + args.num_splits, + args.scale_o, + args.stride_o_acc, + args.stride_o, + args.nhead_stride_lse_acc, + args.nhead_stride_o_acc, + args.nhead_stride_lse, + args.nhead_stride_o, + args.split_stride_lse_acc, + args.split_stride_o_acc); + } + else + { // create batch mode kernel arguments + return Kernel::MakeKargsImpl(args.lse_acc_ptr, + args.o_acc_ptr, + args.lse_ptr, + args.o_ptr, + args.batch, + args.seqlen_q, + args.hdim_v, + args.num_splits, + args.scale_o, + args.stride_o_acc, + args.stride_o, + args.nhead_stride_lse_acc, + args.nhead_stride_o_acc, + args.nhead_stride_lse, + args.nhead_stride_o, + args.batch_stride_lse_acc, + args.batch_stride_o_acc, + args.batch_stride_lse, + args.batch_stride_o, + args.split_stride_lse_acc, + args.split_stride_o_acc); + } + }(); + + dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v); + + return ck_tile::make_tuple(kargs, grids); +} + +template +auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args) +{ + assert(args.nhead_q % args.nhead_k == 0); + auto kargs = Kernel::MakeKargsImpl(args.q_ptr, + args.k_ptr, + args.knew_ptr, + args.v_ptr, + args.vnew_ptr, + args.seqlen_q, + args.seqlen_k_ptr, + args.seqlen_knew, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q 
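+// The combine kernel launched with these kargs merges the per-split partial results from the
+// split-KV pass. Sketch, assuming the conventional log-sum-exp reduction (the actual math lives
+// inside the ck_tile kernel), per output row:
+//   lse = log( sum_s exp(lse_acc[s]) )
+//   o   = sum_s exp(lse_acc[s] - lse) * o_acc[s]
+// which is why lse_acc / o_acc carry an extra split dimension (split_stride_lse_acc,
+// split_stride_o_acc) that the combine step reduces away.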
/ args.nhead_k, + args.rotary_cos_ptr, + args.rotary_sin_ptr, + args.rotary_dim, + args.has_mask, + args.block_table_ptr, + args.batch_stride_block_table, + args.page_block_size, + args.cache_batch_idx, + args.stride_q, + args.stride_k, + args.stride_knew, + args.stride_v, + args.stride_vnew, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_knew, + args.nhead_stride_v, + args.nhead_stride_vnew, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_knew, + args.batch_stride_v, + args.batch_stride_vnew); + + dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_q, args.seqlen_knew); + + return ck_tile::make_tuple(kargs, grids); +} + +// this is used to pattern-match internl kernel implementation, not to instantiate kernel +template +struct fmha_fwd_traits_ +{ + static constexpr ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr ck_tile::index_t kM0 = kM0_; + static constexpr ck_tile::index_t kN0 = kN0_; + static constexpr ck_tile::index_t kK0 = kK0_; + static constexpr ck_tile::index_t kN1 = kN1_; + static constexpr ck_tile::index_t kK1 = kK1_; + static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_; + static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_; + static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_; + using FmhaMask = ck_tile::remove_cvref_t; + static constexpr auto BiasEnum = BiasEnum_; + static constexpr bool kStoreLse = kStoreLse_; + static constexpr bool kHasDropout = kHasDropout_; + static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; + static constexpr bool kPadS = kPadS_; + static constexpr bool kPadSK = kPadSK_; + static constexpr bool kPadD = kPadD_; + static constexpr bool kPadDv = kPadDv_; +}; + +template +float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args); + +template +struct fmha_fwd_splitkv_traits_ +{ + static constexpr ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr ck_tile::index_t kM0 = kM0_; + static constexpr ck_tile::index_t kN0 = kN0_; + static constexpr ck_tile::index_t kK0 = kK0_; + static constexpr ck_tile::index_t kN1 = kN1_; + static constexpr ck_tile::index_t kK1 = kK1_; + static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_; + static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_; + static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_; + using FmhaMask = ck_tile::remove_cvref_t; + static constexpr auto BiasEnum = BiasEnum_; + static constexpr bool kStoreLse = kStoreLse_; + static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; + static constexpr bool kPadS = kPadS_; + static constexpr bool kPadSK = kPadSK_; + static constexpr bool kPadD = kPadD_; + static constexpr bool kPadDv = kPadDv_; + static constexpr bool kIsPagedKV = kIsPagedKV_; +}; + +template +void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config&, fmha_fwd_splitkv_args); + +template +std::string fmha_fwd_splitkv_get_name_(); + +template +struct fmha_fwd_splitkv_combine_traits_ +{ + static constexpr ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr ck_tile::index_t kM0 = kM0_; + static constexpr ck_tile::index_t kN1 = kN1_; + static constexpr bool kStoreLse = kStoreLse_; + static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; + static constexpr bool kPadS = kPadS_; 
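+// These *_traits_ packs act as compile-time keys only: the generated .hip translation units are
+// expected to provide fmha_fwd_<...> (and friends) for specific trait packs, and the
+// script-generated dispatchers declared below pick an instantiation whose pack matches the
+// runtime fmha_fwd_traits fields (hdim, dtype, mask/bias kind, ...). This is an inference from
+// the declarations here, not a guarantee about the generator's exact output.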
+ static constexpr bool kPadDv = kPadDv_; +}; + +template +void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config&, fmha_fwd_splitkv_args); + +template +std::string fmha_fwd_splitkv_combine_get_name_(); + +// this is used to pattern-match internl kernel implementation, not to instantiate kernel +template +struct fmha_fwd_appendkv_traits_ +{ + static constexpr ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr ck_tile::index_t kTileSizeS = kTileSizeS_; + static constexpr ck_tile::index_t kTileSizeSk = kTileSizeSk_; + static constexpr ck_tile::index_t kTileSizeD = kTileSizeD_; + static constexpr ck_tile::index_t kTileSizeDv = kTileSizeDv_; + static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_; + static constexpr bool kPadS = kPadS_; + static constexpr bool kPadSk = kPadSk_; + static constexpr bool kPadD = kPadD_; + static constexpr bool kPadDv = kPadDv_; + static constexpr auto RotaryEnum = RotaryEnum_; + static constexpr bool kIsPagedKV = kIsPagedKV_; +}; + +template +float fmha_fwd_appendkv_(const ck_tile::stream_config&, fmha_fwd_appendkv_args); + +// This is the public API, will be generated by script +struct fmha_fwd_traits +{ + int hdim_q; + int hdim_v; + std::string data_type; + bool is_group_mode; + bool is_v_rowmajor; + mask_enum mask_type; + bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum + bool has_lse; + bool has_dropout; + bool do_fp8_static_quant; + // TODO: padding check is inside this api +}; +float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&); + +struct fmha_fwd_splitkv_traits +{ + int hdim_q; + int hdim_v; + std::string data_type; + bool is_group_mode; + bool is_v_rowmajor; + mask_enum mask_type; + bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum + bool has_lse; + bool do_fp8_static_quant; + // TODO: padding check is inside this api +}; +float fmha_fwd_splitkv(fmha_fwd_splitkv_traits, + fmha_fwd_splitkv_args, + const ck_tile::stream_config&); + +struct fmha_fwd_appendkv_traits +{ + int hdim_q; + int hdim_v; + std::string data_type; + bool is_v_rowmajor; + rope_enum rope_type; +}; +float fmha_fwd_appendkv(fmha_fwd_appendkv_traits, + fmha_fwd_appendkv_args, + const ck_tile::stream_config&); diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/mask.hpp b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mask.hpp new file mode 100644 index 000000000000..133049057d78 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mask.hpp @@ -0,0 +1,157 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include + +#include +#include + +// keep this in sync with ck_tile::GenericAttentionMaskEnum +enum class mask_enum +{ + no_mask = 0, + mask_top_left, + mask_bottom_right, + window_generic, +}; + +struct mask_info +{ + mask_enum type; + ck_tile::index_t y, x; + ck_tile::index_t left, right; // FA style SWA left/right + + void serialize(std::ostream& os) const + { + if(type == mask_enum::no_mask) + os << "n"; + else if(type == mask_enum::mask_top_left) + os << "t(" << left << ":" << right << ")"; + else if(type == mask_enum::mask_bottom_right) + os << "b(" << left << ":" << right << ")"; + else + { + os << "g(" << y << ":" << x << ")"; + } + } + static mask_info decode(std::string str, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k) + { + ck_tile::index_t x_total = seqlen_k; + ck_tile::index_t y_total = seqlen_q; + mask_info tmp; + auto found_0 = str.find(':'); + if(found_0 != std::string::npos) + { + std::string t = str.substr(0, found_0); + std::string v = str.substr(found_0 + 1); + if(t == "xt" || t == "xb") + { + // xformer style sliding window attn from top-left + ck_tile::index_t window_size = atoi(v.c_str()); + ck_tile::index_t left_size = -1; + ck_tile::index_t right_size = 0; + if(window_size > 0) + { + left_size = window_size / 2; + right_size = window_size - 1 - left_size; + } + auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( + left_size, right_size, y_total, x_total, t == "xt"); + + tmp.type = t == "xt" ? mask_enum::mask_top_left : mask_enum::mask_bottom_right; + tmp.y = r.at(ck_tile::number<0>{}); + tmp.x = r.at(ck_tile::number<1>{}); + tmp.left = left_size; + tmp.right = right_size; + } + else + { + auto found_1 = v.find(","); + if(found_1 == std::string::npos) + { + printf("not supported value %s, %s\n", v.c_str(), str.c_str()); + assert(0); + } + tmp.type = mask_enum::window_generic; + ck_tile::index_t v0 = atoi(v.substr(0, found_1).c_str()); + ck_tile::index_t v1 = atoi(v.substr(found_1 + 1).c_str()); + // TODO: some validation + if(t == "t") + { + tmp.type = mask_enum::mask_top_left; + auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( + v0, v1, y_total, x_total, true); + tmp.y = r.at(ck_tile::number<0>{}); + tmp.x = r.at(ck_tile::number<1>{}); + tmp.left = v0; + tmp.right = v1; + } + else if(t == "b") + { + tmp.type = mask_enum::mask_bottom_right; + auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( + v0, v1, y_total, x_total, false); + tmp.y = r.at(ck_tile::number<0>{}); + tmp.x = r.at(ck_tile::number<1>{}); + tmp.left = v0; + tmp.right = v1; + } + else if(t == "g") + { + tmp.y = v0; + tmp.x = v1; + tmp.left = v0; // TODO: don't use this? 
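+// Examples of the strings decode() accepts:
+//   "t" / "b"           causal mask, top-left / bottom-right aligned
+//   "t:l,r" / "b:l,r"   FA-style sliding window with left/right sizes and the given alignment
+//   "g:y,x"             generic mask given directly by its y/x extents
+//   "xt:w" / "xb:w"     xformers-style window of size w, split into left/right halves
+//   "0", "1", "2"       a bare integer selects the mask_enum value itself
+// For instance the callers below build "b:-1,0" for plain causal attention, which decodes to a
+// bottom-right aligned mask with an unbounded left window and right window 0.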
+ tmp.right = v1; + } + else + { + printf("not supported type %s, %s\n", t.c_str(), str.c_str()); + assert(0); + } + } + } + else + { + auto set_causal_top_left = [&]() { + tmp.type = mask_enum::mask_top_left; + tmp.y = seqlen_q; + tmp.x = 1; + tmp.left = -1; + tmp.right = 0; + }; + auto set_causal_bottom_right = [&]() { + tmp.type = mask_enum::mask_bottom_right; + tmp.y = seqlen_q; + tmp.x = seqlen_k - seqlen_q + 1; + tmp.left = -1; + tmp.right = 0; + }; + if(str == "t") + set_causal_top_left(); + else if(str == "b") + set_causal_bottom_right(); + else + { + tmp.type = static_cast(atoi(str.c_str())); + if(tmp.type == mask_enum::mask_top_left) + { + set_causal_top_left(); + } + else if(tmp.type == mask_enum::mask_bottom_right) + { + set_causal_bottom_right(); + } + } + } + return tmp; + } + + friend std::ostream& operator<<(std::ostream& os, const mask_info& mi) + { + mi.serialize(os); + return os; + } +}; diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_bwd_ck.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_bwd_ck.hip new file mode 100644 index 000000000000..7dbb4e1cb568 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_bwd_ck.hip @@ -0,0 +1,407 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#include +#include +#include + +namespace pytorch_flash { + +fmha_bwd_traits get_ck_fmha_bwd_traits(const mask_info &mask, + std::string dtype, + int head_size, + bool has_dropout, + bool enable_alibi, + bool deterministic) +{ + return fmha_bwd_traits{head_size, + head_size, + dtype, + false, // is_group_mode + mask.type, + enable_alibi ? bias_enum::alibi : bias_enum::no_bias, + false, // has_dbias + has_dropout, + false, // s_randval + deterministic}; +} + +fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask, + // sizes + const int b, + const int seqlen_q, + const int seqlen_k, + const int h, + const int h_k, + const int hdim, + // device pointers + const at::Tensor q, + const at::Tensor k, + const at::Tensor v, + c10::optional &alibi_slopes_, + const at::Tensor out, + const at::Tensor softmax_lse, + const at::Tensor dout, + at::Tensor dq_acc, + at::Tensor d, + at::Tensor dq, + at::Tensor dk, + at::Tensor dv, + float softmax_scale, + float p_dropout, + std::pair drop_seed_offset) +{ + // q: (batch_size, seqlen_q, nheads, hdim) + ck_tile::index_t batch_stride_q = q.stride(0); + ck_tile::index_t stride_q = q.stride(1); + ck_tile::index_t nhead_stride_q = q.stride(2); + + // k: (batch_size, seqlen_k, nheads_k, hdim) + ck_tile::index_t batch_stride_k = k.stride(0); + ck_tile::index_t stride_k = k.stride(1); + ck_tile::index_t nhead_stride_k = k.stride(2); + + // v: (batch_size, seqlen_k, nheads_k, hdim) + ck_tile::index_t batch_stride_v = v.stride(0); + ck_tile::index_t stride_v = v.stride(1); + ck_tile::index_t nhead_stride_v = v.stride(2); + + // o: (batch_size, seqlen_q, nheads, hdim) + ck_tile::index_t batch_stride_o = out.stride(0); + ck_tile::index_t stride_o = out.stride(1); + ck_tile::index_t nhead_stride_o = out.stride(2); + + // lse: (batch_size, nheads, seqlen_q) + ck_tile::index_t batch_stride_lse = softmax_lse.stride(0); + ck_tile::index_t nhead_stride_lse = softmax_lse.stride(1); + + // do: (batch_size, seqlen_q, nheads, hdim) + ck_tile::index_t batch_stride_do = dout.stride(0); + ck_tile::index_t stride_do = dout.stride(1); + ck_tile::index_t nhead_stride_do = 
dout.stride(2); + + // d: (batch_size, nheads, seqlen_q) + // CK assume d share the same stride with lse + + // dq: (batch_size, seqlen_q, nheads, hdim) + ck_tile::index_t batch_stride_dq = dq.stride(0); + ck_tile::index_t stride_dq = dq.stride(1); + ck_tile::index_t nhead_stride_dq = dq.stride(2); + + // dk_expanded: (batch_size, seqlen_k, nheads, hdim) + ck_tile::index_t batch_stride_dk = dk.stride(0); + ck_tile::index_t stride_dk = dk.stride(1); + ck_tile::index_t nhead_stride_dk = dk.stride(2); + + // dv_expanded: (batch_size, seqlen_k, nheads, hdim) + ck_tile::index_t batch_stride_dv = dv.stride(0); + ck_tile::index_t stride_dv = dv.stride(1); + ck_tile::index_t nhead_stride_dv = dv.stride(2); + + // dq_acc: (split, batch_size, seqlen_q, nheads, hdim) + ck_tile::index_t split_stride_dq_acc = dq_acc.stride(0); + ck_tile::index_t batch_stride_dq_acc = dq_acc.stride(1); + ck_tile::index_t stride_dq_acc = dq_acc.stride(2); + ck_tile::index_t nhead_stride_dq_acc = dq_acc.stride(3); + + float p_undrop = 1.0 - p_dropout; + + void *alibi_slopes_ptr = nullptr; + ck_tile::index_t stride_alibi_slopes = 0; + + if (alibi_slopes_.has_value()) { + auto alibi_slopes = alibi_slopes_.value(); + CHECK_DEVICE(alibi_slopes); + TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension"); + TORCH_CHECK(alibi_slopes.sizes() == at::IntArrayRef({h}) || alibi_slopes.sizes() == at::IntArrayRef({b, h})); + alibi_slopes_ptr = alibi_slopes.data_ptr(); + // alibi_slopes:(batch_size, nheads) or (nhead) + stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0; + } + + return fmha_bwd_args{q.data_ptr(), + k.data_ptr(), + v.data_ptr(), + alibi_slopes_ptr, // bias + out.data_ptr(), + softmax_lse.data_ptr(), + dout.data_ptr(), + d.data_ptr(), + nullptr, // rand_val + dq.data_ptr(), + dk.data_ptr(), + dv.data_ptr(), + nullptr, // dbias + dq_acc.data_ptr(), // dq_acc + nullptr, // seqstart_q + nullptr, // seqstart_k + nullptr, // seqlen_k_ptr + seqlen_q, + seqlen_k, + b, + seqlen_q, // max_seqlen_q + seqlen_k, // max_seqlen_k + hdim, // hdim_q + hdim, // hdim_v + h, // nhead + h_k, // nhead_k + softmax_scale, + stride_q, + stride_k, + stride_v, + stride_alibi_slopes, + stride_o, + 0, // stride_randval + stride_do, + stride_dq_acc, + stride_dq, + stride_dk, + stride_dv, + 0, // stride_dbias, FA without bias + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + 0, // nhead_stride_bias, FA without bias + nhead_stride_o, + 0, // nhead_stride_randval + nhead_stride_do, + nhead_stride_lse, + nhead_stride_dq_acc, + nhead_stride_dq, + nhead_stride_dk, + nhead_stride_dv, + 0, // nhead_stride_dbias, FA without dbias + batch_stride_q, + batch_stride_k, + batch_stride_v, + 0 , // batch_stride_bias, FA without bias + batch_stride_o, + 0, // batch_stride_randval + batch_stride_do, + batch_stride_lse, + batch_stride_dq_acc, + batch_stride_dq, + batch_stride_dk, + batch_stride_dv, + 0 , // batch_stride_dbias, FA without dbias + split_stride_dq_acc, + mask.left, + mask.right, + static_cast(mask.type), + p_dropout, + p_undrop, + drop_seed_offset}; +} + +std::tuple +mha_bwd_ck(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og + const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &softmax_lse, // b x 
h x seqlen_q + c10::optional &dq_, // batch_size x seqlen_q x num_heads x head_size + c10::optional &dk_, // batch_size x seqlen_k x num_heads_k x head_size + c10::optional &dv_, // batch_size x seqlen_k x num_heads_k x head_size + c10::optional &alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, // probability to drop + const float softmax_scale, + const bool is_causal, + int window_size_left, + int window_size_right, + const bool deterministic, + const at::Tensor philox_seed, + const at::Tensor philox_offset) +{ +#ifdef FLASHATTENTION_DISABLE_BACKWARD + TORCH_CHECK(false, "This flash attention build does not support backward."); +#endif + if (is_causal) { window_size_right = 0; } + + bool is_dropout = p_dropout > 0.0; + auto stream = at::cuda::getCurrentHIPStream().stream(); + + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16, + "FlashAttention only support fp16 and bf16 data type"); + + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype"); + TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype"); + + std::string q_dtype_str = q_dtype == at::kHalf ? "fp16" : "bf16"; + + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension"); + TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension"); + + const auto sizes = q.sizes(); + + const int batch_size = sizes[0]; + const int seqlen_q = sizes[1]; + const int num_heads = sizes[2]; + const int head_size_og = dout.size(3); // unpadded hdim + const int head_size_8x = sizes[3]; + const int seqlen_k = k.size(1); + const int num_heads_k = k.size(2); + TORCH_CHECK(batch_size > 0, "batch size must be positive"); + TORCH_CHECK(head_size_8x % 8 == 0, "head_size_8x should be a multiple of 8"); + TORCH_CHECK(head_size_8x <= 128, "CK FlashAttention backward only supports head dimension at most 128"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + TORCH_CHECK(head_size_8x == round_multiple(head_size_og, 8), "head_size_8x must be head_size_og rounded to a multiple of 8"); + + if (window_size_left >= seqlen_k) { window_size_left = -1; } + if (window_size_right >= seqlen_k) { window_size_right = -1; } + + mask_info mask; + if (is_causal) { + std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + "0"; + mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // casual + } + else if (window_size_left == -1 && window_size_right == -1) { + mask = mask_info::decode("0", seqlen_q, seqlen_k); // no mask + } + else { + // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. 
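+// For example head_size_og = 72 rounds up to head_size_8x = 80 (q/k/v/out were already padded in
+// the forward pass; dout is re-padded below). The mask_identify string reuses the mask.hpp
+// grammar: window_size_left = -1 here yields "b:-1,0", i.e. a plain bottom-right causal mask.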
+ std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + std::to_string(window_size_right); + mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // local + } + + // q, k, v, out had been padded in mha_fwd + // dq_, dk_, dv_ are also padded tensor + CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_8x); + CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_8x); + CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_8x); + CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_8x); + CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_og); + + at::Tensor dq, dk, dv; + if (dq_.has_value()) { + dq = dq_.value(); + TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q"); + CHECK_DEVICE(dq); + TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension"); + CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size_8x); + } else { + dq = at::empty_like(q); + } + if (dk_.has_value()) { + dk = dk_.value(); + TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q"); + CHECK_DEVICE(dk); + TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension"); + CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size_8x); + } else { + dk = at::empty_like(k); + } + if (dv_.has_value()) { + dv = dv_.value(); + TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q"); + CHECK_DEVICE(dv); + TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); + CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size_8x); + } else { + dv = at::empty_like(v); + } + + at::Tensor dout_padded; + if (head_size_og % 8 != 0) { + dout_padded = at::pad(dout, {0, 8 - head_size_og % 8}); + } else { + dout_padded = dout; + } + + // Cast to char to avoid compiler warning about narrowing + at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()}; + + auto opts = q.options(); + auto softmax_d = at::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); + at::Tensor dq_accum; + + if (!deterministic) { + dq_accum = at::zeros({1, batch_size, seqlen_q, num_heads, head_size_8x}, opts.dtype(at::kFloat)); + } else { + const ck_tile::index_t kN0 = head_size_8x <= 128 ? 
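+// In deterministic mode dq is accumulated per seqlen_k tile instead of via atomics: e.g. with
+// kN0 = 128 and seqlen_k = 4096, nsplits = ceil(4096 / 128) = 32 and dq_accum has shape
+// (32, batch, seqlen_q, nheads, head_size_8x); the backward pass then folds the splits into dq
+// (hence the split_stride_dq_acc argument above).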
128 : 64; + const ck_tile::index_t nsplits = ck_tile::integer_divide_ceil(seqlen_k, kN0); + dq_accum = at::zeros({nsplits, batch_size, seqlen_q, num_heads, head_size_8x}, opts.dtype(at::kFloat)); + } + + at::Tensor dk_expanded, dv_expanded; + if (num_heads_k != num_heads) { // MQA / GQA + dk_expanded = at::empty({batch_size, seqlen_k, num_heads, head_size_8x}, opts); + dv_expanded = at::empty({batch_size, seqlen_k, num_heads, head_size_8x}, opts); + } else { + dk_expanded = dk; + dv_expanded = dv; + } + + uint64_t drop_seed = 1, drop_offset = 0; + drop_seed = *philox_seed.data_ptr(); + drop_offset = *philox_offset.data_ptr(); + auto drop_seed_offset = std::make_pair(&drop_seed, &drop_offset); + + + if (seqlen_q > 0) { + ck_tile::stream_config stream_config{stream}; + dq.zero_(); // ck use atomic operation on dq + auto traits = + get_ck_fmha_bwd_traits(mask, q_dtype_str, head_size_8x, is_dropout, alibi_slopes_.has_value(), deterministic); + + auto args = + get_ck_fmha_bwd_args( + mask, + batch_size, + seqlen_q, + seqlen_k, + num_heads, + num_heads_k, + head_size_8x, + q, + k, + v, + alibi_slopes_, + out, + softmax_lse, + dout_padded, + dq_accum, + softmax_d, + dq, + dk_expanded, + dv_expanded, + softmax_scale, + p_dropout, + drop_seed_offset); +#if (defined(__gfx90a__) || defined(__gfx942__)) + float t = fmha_bwd(traits, args, stream_config); + TORCH_CHECK(t >= 0, "invalid argument for fmha_bwd"); +#endif + } else { + // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0. + dk_expanded.zero_(); + dv_expanded.zero_(); + softmax_d.zero_(); + } + + // For MQA/GQA we need to sum dK and dV across the groups + if (num_heads_k != num_heads) { + at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size_8x}), {3}); + at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size_8x}), {3}); + } + if (head_size_og % 8 != 0) { + dq = dq.index({"...", at::indexing::Slice(at::indexing::None, head_size_og)}); + dk = dk.index({"...", at::indexing::Slice(at::indexing::None, head_size_og)}); + dv = dv.index({"...", at::indexing::Slice(at::indexing::None, head_size_og)}); + } + + return { dq, dk, dv, softmax_d }; +} +} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_fwd_ck.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_fwd_ck.hip new file mode 100644 index 000000000000..f66dee9c95ff --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_fwd_ck.hip @@ -0,0 +1,360 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#include +#include +#include + + + +namespace pytorch_flash { + + +fmha_fwd_traits get_ck_fmha_fwd_traits(const mask_info &mask, + std::string dtype, + int head_size, + bool has_dropout, + bool has_lse, + bool enable_alibi) +{ + return fmha_fwd_traits{head_size, + head_size, + dtype, + false, // is_group_mode + true, // is_v_rowmajor + mask.type, + enable_alibi ? 
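+// Worked example for the MQA/GQA reduction in mha_bwd_ck above: with num_heads = 32 and
+// num_heads_k = 8, dk_expanded has shape (batch, seqlen_k, 32, head_size_8x); it is viewed as
+// (batch, seqlen_k, 8, 4, head_size_8x) and summed over the group dimension, so each KV head
+// collects the gradients of the 4 query heads that attended through it.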
bias_enum::alibi : bias_enum::no_bias, + has_lse, + has_dropout, + false}; // do_fp8_static_quant +} + +fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse, + bool has_dropout_randval, + const mask_info &mask, + // sizes + const int b, + const int seqlen_q, + const int seqlen_k, + const int h, + const int h_k, + const int d, + // device pointers + const at::Tensor q, + const at::Tensor k, + const at::Tensor v, + std::optional &alibi_slopes_, + at::Tensor out, + at::Tensor softmax_lse, + at::Tensor dropout_randval, + float softmax_scale, + float p_dropout, + std::pair drop_seed_offset) +{ + // q: (batch_size, seqlen_q, nheads, d) + // k: (batch_size, seqlen_k, nheads_k, d) + // v: (batch_size, seqlen_k, nheads_k, d) + // o: (batch_size, seqlen_q, nheads, d) + + // alibi_slopes:(batch_size, nheads) or (nhead) + // lse: (batch_size, nheads, seqlen_q) + // randval: (batch_size, nheads, seqlen_q, seqlen_k) + + ck_tile::index_t stride_q = q.stride(1); + ck_tile::index_t stride_k = k.stride(1); + ck_tile::index_t stride_v = v.stride(1); + ck_tile::index_t stride_o = out.stride(1); + ck_tile::index_t stride_randval = has_dropout_randval ? dropout_randval.stride(2) : 0; + + ck_tile::index_t nhead_stride_q = q.stride(2); + ck_tile::index_t nhead_stride_k = k.stride(2); + ck_tile::index_t nhead_stride_v = v.stride(2); + ck_tile::index_t nhead_stride_o = out.stride(2); + ck_tile::index_t nhead_stride_lse = has_lse ? softmax_lse.stride(1) : 0; + ck_tile::index_t nhead_stride_randval = has_dropout_randval ? dropout_randval.stride(1) : 0; + + ck_tile::index_t batch_stride_q = q.stride(0); + ck_tile::index_t batch_stride_k = k.stride(0); + ck_tile::index_t batch_stride_v = v.stride(0); + ck_tile::index_t batch_stride_o = out.stride(0); + + ck_tile::index_t batch_stride_lse = has_lse ? softmax_lse.stride(0) : 0; + ck_tile::index_t batch_stride_randval = has_dropout_randval ? dropout_randval.stride(0) : 0; + + void *alibi_slopes_ptr = nullptr; + ck_tile::index_t stride_alibi_slopes = 0; + + if (alibi_slopes_.has_value()) { + auto alibi_slopes = alibi_slopes_.value(); + CHECK_DEVICE(alibi_slopes); + TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension"); + TORCH_CHECK(alibi_slopes.sizes() == at::IntArrayRef({h}) || alibi_slopes.sizes() == at::IntArrayRef({b, h})); + alibi_slopes_ptr = alibi_slopes.data_ptr(); + stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0; + } + + return fmha_fwd_args{q.data_ptr(), + k.data_ptr(), + v.data_ptr(), + alibi_slopes_ptr, // bias + has_dropout_randval ? dropout_randval.data_ptr() : nullptr, + has_lse ? 
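+// For a contiguous q of shape (batch, seqlen_q, nheads, d) the three stride levels reduce to
+//   stride_q       = q.stride(1) = nheads * d            (one query token to the next)
+//   nhead_stride_q = q.stride(2) = d                     (one head to the next)
+//   batch_stride_q = q.stride(0) = seqlen_q * nheads * d
+// passing them separately lets the same kernel consume non-contiguous layouts as well.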
softmax_lse.data_ptr() : nullptr, + out.data_ptr(), + nullptr, // seqstart_q + nullptr, // seqstart_k + nullptr, + seqlen_q, + seqlen_k, + b, + seqlen_q, // max_seqlen_q + d, // hdim_q + d, // hdim_v + h, // nhead + h_k, // nhead_k + softmax_scale, // scale_s + 1, // scale_p + 1, // scale_o + stride_q, + stride_k, + stride_v, + stride_alibi_slopes, + stride_randval, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + 0, // nhead_stride_bias, FA without bias + nhead_stride_randval, + nhead_stride_lse, + nhead_stride_o, + batch_stride_q, + batch_stride_k, + batch_stride_v, + 0, // batch_stride_bias, FA without bias + batch_stride_randval, + batch_stride_lse, + batch_stride_o, + mask.left, + mask.right, + static_cast(mask.type), + p_dropout, + has_dropout_randval, + drop_seed_offset}; +} +std::tuple +mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size + c10::optional &out_, // batch_size x seqlen_q x num_heads xhead_size + c10::optional &alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, + const float softmax_scale, + bool is_causal, + int window_size_left, + int window_size_right, + const bool return_dropout_randval, + c10::optional gen_) +{ + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16, + "FlashAttention only support fp16 and bf16 data type"); + + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + + std::string q_dtype_str = q_dtype == at::kHalf ? "fp16" : "bf16"; + + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + + const auto sizes = q.sizes(); + + const int batch_size = sizes[0]; + int seqlen_q = sizes[1]; + int num_heads = sizes[2]; + const int head_size = sizes[3]; + const int seqlen_k = k.size(1); + const int num_heads_k = k.size(2); + TORCH_CHECK(batch_size > 0, "batch size must be positive"); + TORCH_CHECK(head_size <= 256, "CK only supports head dimension at most 256"); + TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + if (window_size_left >= seqlen_k) { window_size_left = -1; } + if (window_size_right >= seqlen_k) { window_size_right = -1; } + + // causal=true is the same as causal=false in this case + if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } + + mask_info mask; + if (is_causal) { + // Causal is the special case where window_size_right == 0 and window_size_left < 0. + window_size_right = 0; + std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + "0"; + mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // casual + } + else if (window_size_left == -1 && window_size_right == -1) { + mask = mask_info::decode("0", seqlen_q, seqlen_k); // no mask + } + else { + // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. 
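+// Dropping is_causal when seqlen_q == 1 above is safe because a single bottom-right-aligned
+// query row attends to the whole key range anyway, so the causal mask excludes nothing.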
+ std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + std::to_string(window_size_right); + mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // local + } + + // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case + // H/t Daniel Haziza + const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size % 8 == 0 && !alibi_slopes_.has_value(); + const int ngroups = num_heads / num_heads_k; + at::Tensor temp_q = q; + if (seqlenq_ngroups_swapped) { + temp_q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2); + seqlen_q = ngroups; + num_heads = num_heads_k; + } + + CHECK_SHAPE(temp_q, batch_size, seqlen_q, num_heads, head_size); + CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size); + CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size); + + + at::Tensor q_padded, k_padded, v_padded; + if (head_size % 8 != 0) { + q_padded = at::pad(temp_q, {0, 8 - head_size % 8}); + k_padded = at::pad(k, {0, 8 - head_size % 8}); + v_padded = at::pad(v, {0, 8 - head_size % 8}); + } + else { + q_padded = temp_q; + k_padded = k; + v_padded = v; + } + + + at::Tensor out; + if (out_.has_value()) { + out = out_.value(); + TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); + CHECK_DEVICE(out); + TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); + CHECK_SHAPE(out, batch_size, sizes[1], sizes[2], head_size); + if (seqlenq_ngroups_swapped) { + out = out.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2); + } + if (head_size % 8 != 0) { out = at::empty_like(q_padded); }; + } + else { + out = at::empty_like(q); + } + + auto round_multiple = [](int x, int m) { return (x + m -1) / m*m;}; + const int head_size_8x = round_multiple(head_size, 8); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()}; + + auto opts = q.options(); + bool has_lse = true; + bool has_dropout = p_dropout > 0.0f; + + at::Tensor softmax_lse; + // TODO - check gradient, only training require lse + softmax_lse = at::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); + + at::Tensor p; + if (return_dropout_randval) { + TORCH_CHECK(has_dropout, "return_dropout_randval require p_dropout > 0"); + p = at::empty({batch_size, num_heads, seqlen_q, seqlen_k}, opts.dtype(at::kByte)); + } + else { + p = at::empty({ 0 }, opts); + } + + int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size(); + auto rng_state = at::empty({2}, opts.dtype(at::kLong)); + auto rng_state_ptr = reinterpret_cast(rng_state.data_ptr()); + + + + at::Tensor seed_t, offset_t; + + if (p_dropout > 0.0) { + auto gen = at::get_generator_or_default( + gen_, at::cuda::detail::getDefaultCUDAGenerator()); + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + + auto philox_args = gen->philox_cuda_state(counter_offset); + + + + hipLaunchKernelGGL( + flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, at::hip::getCurrentHIPStreamMasqueradingAsCUDA(), philox_args, rng_state_ptr); + seed_t = at::scalar_tensor(at::Scalar(static_cast(rng_state_ptr[0])), at::dtype(at::kLong)); + offset_t = at::scalar_tensor(at::Scalar(static_cast(rng_state_ptr[1])), at::dtype(at::kLong)); + } + else + { + seed_t = 
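+// Shape-level example of the seqlenq_ngroups_swapped trick above: for decode-style inputs
+// (seqlen_q == 1, GQA, no dropout/windowing/alibi), q of shape (batch, 1, 32, d) with
+// num_heads_k = 8 is viewed as (batch, 8, 4, d) and transposed to (batch, 4, 8, d), so the
+// kernel runs with seqlen_q = ngroups = 4 and num_heads = 8; the output is transposed back to
+// (batch, 1, 32, d) after the call.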
at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); + offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); + } + + if (seqlen_k > 0) { + auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1); + auto stream = at::cuda::getCurrentHIPStream().stream(); + ck_tile::stream_config stream_config{stream}; + + auto traits = + get_ck_fmha_fwd_traits( + mask, + q_dtype_str, + head_size_8x, + has_dropout, + has_lse, + alibi_slopes_.has_value()); + + auto args = + get_ck_fmha_fwd_args( + has_lse, + return_dropout_randval, + mask, + batch_size, + seqlen_q, + seqlen_k, + num_heads, + num_heads_k, + head_size_8x, + q, + k, + v, + alibi_slopes_, + out, + softmax_lse, + p, + softmax_scale, + p_dropout, + drop_seed_offset); +#if (defined(__gfx90a__) || defined(__gfx942__)) + float t = fmha_fwd(traits, args, stream_config); + TORCH_CHECK(t >= 0, "invalid argument for fmha_fwd"); +#endif + } + else { + // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. + out.zero_(); + softmax_lse.fill_(std::numeric_limits::infinity()); + } + + if (seqlenq_ngroups_swapped) { + out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size}); + q_padded = q_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size}); + softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1}); + } + return {out, q_padded, k_padded, v_padded, softmax_lse, seed_t, offset_t, p}; +} +} //namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_varlen_bwd_ck.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_varlen_bwd_ck.hip new file mode 100644 index 000000000000..708096392d00 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_varlen_bwd_ck.hip @@ -0,0 +1,436 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#include +#include +#include + + +namespace pytorch_flash { + + +fmha_bwd_traits get_ck_fmha_varlen_bwd_traits(const mask_info &mask, + std::string dtype, + int head_size, + bool has_dropout, + bool enable_alibi, + bool deterministic) +{ + return fmha_bwd_traits{head_size, + head_size, + dtype, + true, // is_group_mode + mask.type, + enable_alibi ? 
bias_enum::alibi : bias_enum::no_bias, + false, // has_dbias + has_dropout, + false, // s_randval + deterministic}; +} + +fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask, + // sizes + const int b, + const int max_seqlen_q, + const int max_seqlen_k, + const int h, + const int h_k, + const int hdim, + // device pointers + const at::Tensor q, + const at::Tensor k, + const at::Tensor v, + const at::Tensor seqlens_q, + const at::Tensor seqlens_k, + c10::optional &alibi_slopes_, + const at::Tensor out, + const at::Tensor softmax_lse, + const at::Tensor dout, + at::Tensor dq_acc, + at::Tensor d, + at::Tensor dq, + at::Tensor dk, + at::Tensor dv, + float softmax_scale, + float p_dropout, + std::pair drop_seed_offset) +{ + ck_tile::index_t total_q = q.size(0); + ck_tile::index_t total_k = k.size(0); + + // q: (total_q, nheads, hdim) + ck_tile::index_t batch_stride_q = 0; + ck_tile::index_t stride_q = q.stride(0); + ck_tile::index_t nhead_stride_q = q.stride(1); + + // k: (total_k, nheads_k, hdim) + ck_tile::index_t batch_stride_k = 0; + ck_tile::index_t stride_k = k.stride(0); + ck_tile::index_t nhead_stride_k = k.stride(1); + + // v: (total_k, nheads_k, hdim) + ck_tile::index_t batch_stride_v = 0; + ck_tile::index_t stride_v = v.stride(0); + ck_tile::index_t nhead_stride_v = v.stride(1); + + // o: (total_q, nheads, hdim) + ck_tile::index_t batch_stride_o = 0; + ck_tile::index_t stride_o = out.stride(0); + ck_tile::index_t nhead_stride_o = out.stride(1); + + // lse: (nheads, total_q) + ck_tile::index_t batch_stride_lse = 0; + ck_tile::index_t nhead_stride_lse = softmax_lse.stride(0); + + // do: (total_q, nheads, hdim) + ck_tile::index_t batch_stride_do = 0; + ck_tile::index_t stride_do = dout.stride(0); + ck_tile::index_t nhead_stride_do = dout.stride(1); + + // d: (batch_size, nheads, max_seqlen_q) + // CK assume d share the same stride with lse + + // dq: (total_q, nheads, hdim) + ck_tile::index_t batch_stride_dq = 0; + ck_tile::index_t stride_dq = dq.stride(0); + ck_tile::index_t nhead_stride_dq = dq.stride(1); + + + // dk_expanded: (total_k, nheads, hdim) + ck_tile::index_t batch_stride_dk = 0; + ck_tile::index_t stride_dk = dk.stride(0); + ck_tile::index_t nhead_stride_dk = dk.stride(1); + + // dv_expanded: (total_k, nheads, hdim) + ck_tile::index_t batch_stride_dv = 0; + ck_tile::index_t stride_dv = dv.stride(0); + ck_tile::index_t nhead_stride_dv = dv.stride(1); + + // dq_acc: (split, total_q, nheads, hdim) + ck_tile::index_t split_stride_dq_acc = dq_acc.stride(0); + ck_tile::index_t batch_stride_dq_acc = 0; + ck_tile::index_t stride_dq_acc = dq_acc.stride(1); + ck_tile::index_t nhead_stride_dq_acc = dq_acc.stride(2); + + float p_undrop = 1.0 - p_dropout; + + void *alibi_slopes_ptr = nullptr; + ck_tile::index_t stride_alibi_slopes = 0; + + if (alibi_slopes_.has_value()) { + auto alibi_slopes = alibi_slopes_.value(); + CHECK_DEVICE(alibi_slopes); + TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension"); + TORCH_CHECK(alibi_slopes.sizes() == at::IntArrayRef({h}) || alibi_slopes.sizes() == at::IntArrayRef({b, h})); + alibi_slopes_ptr = alibi_slopes.data_ptr(); + // alibi_slopes:(batch_size, nheads) or (nhead) + stride_alibi_slopes = alibi_slopes.dim() == 2 ? 
alibi_slopes.stride(0) : 0; + } + + return fmha_bwd_args{q.data_ptr(), + k.data_ptr(), + v.data_ptr(), + alibi_slopes_ptr, // bias + out.data_ptr(), + softmax_lse.data_ptr(), + dout.data_ptr(), + d.data_ptr(), + nullptr, // rand_val + dq.data_ptr(), + dk.data_ptr(), + dv.data_ptr(), + nullptr, // dbias + dq_acc.data_ptr(), // dq_acc + seqlens_q.data_ptr(), // seqstart_q + seqlens_k.data_ptr(), // seqstart_k + nullptr, // seqlen_k_ptr + total_q, + total_k, + b, + max_seqlen_q, // max_seqlen_q + max_seqlen_k, // max_seqlen_k + hdim, // hdim_q + hdim, // hdim_v + h, // nhead + h_k, // nhead_k + softmax_scale, + stride_q, + stride_k, + stride_v, + stride_alibi_slopes, + stride_o, + 0, // stride_randval + stride_do, + stride_dq_acc, + stride_dq, + stride_dk, + stride_dv, + 0, // stride_dbias, FA without bias + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + 0, // nhead_stride_bias, FA without bias + nhead_stride_o, + 0, // nhead_stride_randval + nhead_stride_do, + nhead_stride_lse, + nhead_stride_dq_acc, + nhead_stride_dq, + nhead_stride_dk, + nhead_stride_dv, + 0, // nhead_stride_dbias, FA without dbias + batch_stride_q, + batch_stride_k, + batch_stride_v, + 0 , // batch_stride_bias, FA without bias + batch_stride_o, + 0, // batch_stride_randval + batch_stride_do, + batch_stride_lse, + batch_stride_dq_acc, + batch_stride_dq, + batch_stride_dk, + batch_stride_dv, + 0 , // batch_stride_dbias, FA without dbias + split_stride_dq_acc, + mask.left, + mask.right, + static_cast(mask.type), + p_dropout, + p_undrop, + drop_seed_offset}; +} + +std::tuple +mha_varlen_bwd_ck(const at::Tensor &dout, // total_q x num_heads x head_size + const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &out, // total_q x num_heads x head_size + const at::Tensor &softmax_lse, // b x h x s softmax logsumexp + c10::optional &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + c10::optional &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + c10::optional &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &cu_seqlens_q, // b+1 + const at::Tensor &cu_seqlens_k, // b+1 + c10::optional &alibi_slopes_, // num_heads or b x num_heads + const int max_seqlen_q, + const int max_seqlen_k, // max sequence length to choose the kernel + const float p_dropout, // probability to drop + const float softmax_scale, + const bool zero_tensors, + const bool is_causal, + int window_size_left, + int window_size_right, + const bool deterministic, + const at::Tensor philox_seed, + const at::Tensor philox_offset) +{ +#ifdef FLASHATTENTION_DISABLE_BACKWARD + TORCH_CHECK(false, "This flash attention build does not support backward."); +#endif + if (is_causal) { window_size_right = 0; } + + bool is_dropout = p_dropout > 0.0; + auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream(); + + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16, + "FlashAttention only support fp16 and bf16 data type"); + + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype"); + TORCH_CHECK(dout.dtype() == q_dtype, "query 
and dout must have the same dtype"); + TORCH_CHECK(cu_seqlens_q.dtype() == at::kInt, "cu_seqlens_q must have dtype int32"); + TORCH_CHECK(cu_seqlens_k.dtype() == at::kInt, "cu_seqlens_k must have dtype int32"); + + std::string q_dtype_str = q_dtype == at::kHalf ? "fp16" : "bf16"; + + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse); + CHECK_DEVICE(cu_seqlens_q); CHECK_DEVICE(cu_seqlens_k); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension"); + TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension"); + CHECK_CONTIGUOUS(cu_seqlens_q); + CHECK_CONTIGUOUS(cu_seqlens_k); + + const auto sizes = q.sizes(); + + const int total_q = sizes[0]; + const int batch_size = cu_seqlens_q.numel() - 1; + const int num_heads = sizes[1]; + const int head_size_og = dout.size(2); + const int head_size_8x = sizes[2]; + const int total_k = k.size(0); + const int num_heads_k = k.size(1); + TORCH_CHECK(batch_size > 0, "batch size must be positive"); + TORCH_CHECK(head_size_8x % 8 == 0, "head_size should be a multiple of 8"); + TORCH_CHECK(head_size_8x <= 128, "CK FlashAttention backward only supports head dimension at most 128"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + TORCH_CHECK(head_size_8x == round_multiple(head_size_og, 8), "head_size_8x must be head_size_og rounded to a multiple of 8"); + + if (window_size_left >= max_seqlen_k) { window_size_left = -1; } + if (window_size_right >= max_seqlen_k) { window_size_right = -1; } + + mask_info mask; + if (is_causal) { + std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + "0"; + mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // casual + } + else if (window_size_left == -1 && window_size_right == -1) { + mask = mask_info::decode("0", max_seqlen_q, max_seqlen_k); // no mask + } + else { + // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. 
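+// Varlen packing used by these entry points: with per-batch query lengths {3, 5, 2},
+// cu_seqlens_q = {0, 3, 8, 10} (hence the "b+1" entries), total_q = 10, and q is stored as
+// (total_q, num_heads, head_size) with no batch dimension; each sequence's extent comes from
+// consecutive cu_seqlens entries, which is why the group-mode strides above set
+// batch_stride_* = 0.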
+ std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + std::to_string(window_size_right); + mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // local + } + + // q, k, v, out had been padded in mha_fwd + // dq_, dk_, dv_ are also padded tensor + CHECK_SHAPE(q, total_q, num_heads, head_size_8x); + CHECK_SHAPE(k, total_k, num_heads_k, head_size_8x); + CHECK_SHAPE(v, total_k, num_heads_k, head_size_8x); + CHECK_SHAPE(out, total_q, num_heads, head_size_8x); + CHECK_SHAPE(dout, total_q, num_heads, head_size_og); + CHECK_SHAPE(cu_seqlens_q, batch_size + 1); + CHECK_SHAPE(cu_seqlens_k, batch_size + 1); + + at::Tensor dq, dk, dv; + if (dq_.has_value()) { + dq = dq_.value(); + TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q"); + CHECK_DEVICE(dq); + TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension"); + CHECK_SHAPE(dq, total_q, num_heads, head_size_8x); + } else { + dq = at::empty_like(q); + } + if (dk_.has_value()) { + dk = dk_.value(); + TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q"); + CHECK_DEVICE(dk); + TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension"); + CHECK_SHAPE(dk, total_k, num_heads_k, head_size_8x); + } else { + dk = at::empty_like(k); + } + if (dv_.has_value()) { + dv = dv_.value(); + TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q"); + CHECK_DEVICE(dv); + TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); + CHECK_SHAPE(dv, total_k, num_heads_k, head_size_8x); + } else { + dv = at::empty_like(v); + } + + at::Tensor dout_padded; + if (head_size_og % 8 != 0) { + dout_padded = at::pad(dout, {0, 8 - head_size_og % 8}); + } else { + dout_padded = dout; + } + + // Cast to char to avoid compiler warning about narrowing + at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()}; + + auto opts = q.options(); + auto softmax_d = at::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); + at::Tensor dq_accum; + + if (!deterministic) { + dq_accum = at::zeros({1, total_q, num_heads, head_size_8x}, opts.dtype(at::kFloat)); + } else { + const ck_tile::index_t kN0 = head_size_8x <= 128 ? 
128 : 64; + const ck_tile::index_t nsplits = ck_tile::integer_divide_ceil(max_seqlen_k, kN0); + dq_accum = at::zeros({nsplits, total_q, num_heads, head_size_8x}, opts.dtype(at::kFloat)); + } + + at::Tensor dk_expanded, dv_expanded; + if (num_heads_k != num_heads) { // MQA / GQA + dk_expanded = at::empty({total_k, num_heads, head_size_8x}, opts); + dv_expanded = at::empty({total_k, num_heads, head_size_8x}, opts); + } else { + dk_expanded = dk; + dv_expanded = dv; + } + + if(zero_tensors) { + dq.zero_(); + dk_expanded.zero_(); + dv_expanded.zero_(); + softmax_d.zero_(); + } + + uint64_t drop_seed = 1, drop_offset = 0; + + drop_seed = *philox_seed.data_ptr(); + drop_offset = *philox_offset.data_ptr(); + auto drop_seed_offset = std::make_pair(&drop_seed, &drop_offset); + + if (max_seqlen_q > 0) { + ck_tile::stream_config stream_config{stream}; + dq.zero_(); // ck use atomic operation on dq + auto traits = + get_ck_fmha_varlen_bwd_traits(mask, q_dtype_str, head_size_8x, is_dropout, alibi_slopes_.has_value(), deterministic); + + auto args = + get_ck_fmha_varlen_bwd_args( + mask, + batch_size, + max_seqlen_q, + max_seqlen_k, + num_heads, + num_heads_k, + head_size_8x, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + alibi_slopes_, + out, + softmax_lse, + dout_padded, + dq_accum, + softmax_d, + dq, + dk_expanded, + dv_expanded, + softmax_scale, + p_dropout, + drop_seed_offset); +#if (defined(__gfx90a__) || defined(__gfx942__)) + float t = fmha_bwd(traits, args, stream_config); + TORCH_CHECK(t >= 0, "invalid argument for fmha_bwd"); +#endif + } else { + // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0. + dk_expanded.zero_(); + dv_expanded.zero_(); + softmax_d.zero_(); + } + + // For MQA/GQA we need to sum dK and dV across the groups + if (num_heads_k != num_heads) { + at::sum_out(dk, at::reshape(dk_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size_8x}), {2}); + at::sum_out(dv, at::reshape(dv_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size_8x}), {2}); + } + if (head_size_og % 8 != 0) { + dq = dq.index({"...", at::indexing::Slice(at::indexing::None, head_size_og)}); + dk = dk.index({"...", at::indexing::Slice(at::indexing::None, head_size_og)}); + dv = dv.index({"...", at::indexing::Slice(at::indexing::None, head_size_og)}); + } + + return { dq, dk, dv, softmax_d }; +} +} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_varlen_fwd_ck.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_varlen_fwd_ck.hip new file mode 100644 index 000000000000..1ffdcd5b9bff --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_varlen_fwd_ck.hip @@ -0,0 +1,364 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#include +#include +#include + +namespace pytorch_flash { + +fmha_fwd_traits get_ck_fmha_varlen_fwd_traits(const mask_info &mask, + std::string dtype, + int head_size, + bool has_dropout, + bool has_lse, + bool enable_alibi) +{ + return fmha_fwd_traits{head_size, + head_size, + dtype, + true, // is_group_mode + true, // is_v_rowmajor + mask.type, + enable_alibi ? 
bias_enum::alibi : bias_enum::no_bias, + has_lse, + has_dropout, + false}; // do_fp8_static_quant +} + +fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse, + bool has_dropout_randval, + const mask_info &mask, + // sizes + const int b, + const int max_seqlen_q, + const int h, + const int h_k, + const int d, + // device pointers + const at::Tensor q, + const at::Tensor k, + const at::Tensor v, + const at::Tensor seqlens_q, + const at::Tensor seqlens_k, + c10::optional &alibi_slopes_, + at::Tensor out, + at::Tensor softmax_lse, + at::Tensor dropout_randval, + + float softmax_scale, + float p_dropout, + std::pair drop_seed_offset) +{ + // q: (total_q, nheads, d) + // k: (total_k, nheads_k, d) + // v: (total_k, nheads_k, d) + // o: (total_q, nheads, d) + + // alibi_slopes:(batch, nheads) or (nhead) + // lse: (batch, nheads, max_seqlen_q) + // randval: (nheads, total_q, max_seqlen_k) + + ck_tile::index_t total_q = q.size(0); + ck_tile::index_t total_k = k.size(0); + + ck_tile::index_t stride_q = q.stride(0); + ck_tile::index_t stride_k = k.stride(0); + ck_tile::index_t stride_v = v.stride(0); + ck_tile::index_t stride_o = out.stride(0); + ck_tile::index_t stride_randval = has_dropout_randval ? dropout_randval.stride(1) : 0; + + ck_tile::index_t nhead_stride_q = q.stride(1); + ck_tile::index_t nhead_stride_k = k.stride(1); + ck_tile::index_t nhead_stride_v = v.stride(1); + ck_tile::index_t nhead_stride_o = out.stride(1); + ck_tile::index_t nhead_stride_lse = has_lse ? softmax_lse.stride(1) : 0; + ck_tile::index_t nhead_stride_randval = has_dropout_randval ? dropout_randval.stride(0) : 0; + + ck_tile::index_t batch_stride_q = 0; + ck_tile::index_t batch_stride_k = 0; + ck_tile::index_t batch_stride_v = 0; + ck_tile::index_t batch_stride_o = 0; + + ck_tile::index_t batch_stride_lse = has_lse ? softmax_lse.stride(0) : 0; + ck_tile::index_t batch_stride_randval = 0; + + void *alibi_slopes_ptr = nullptr; + ck_tile::index_t stride_alibi_slopes = 0; + + if (alibi_slopes_.has_value()) { + auto alibi_slopes = alibi_slopes_.value(); + CHECK_DEVICE(alibi_slopes); + TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension"); + TORCH_CHECK(alibi_slopes.sizes() == at::IntArrayRef({h}) || alibi_slopes.sizes() == at::IntArrayRef({b, h})); + alibi_slopes_ptr = alibi_slopes.data_ptr(); + stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0; + } + + return fmha_fwd_args{q.data_ptr(), + k.data_ptr(), + v.data_ptr(), + alibi_slopes_ptr, // bias + has_dropout_randval ? dropout_randval.data_ptr() : nullptr, + has_lse ? 
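/* Group-mode (varlen) argument packing: sequences are concatenated along dim 0 and located through cu_seqlens_q / cu_seqlens_k (passed below as seqstart_q / seqstart_k), so the per-batch strides for q/k/v/o are all 0; the LSE and dropout-randval pointers and strides are only populated when those outputs are actually requested. */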
softmax_lse.data_ptr() : nullptr, + out.data_ptr(), + seqlens_q.data_ptr(), // seqstart_q + seqlens_k.data_ptr(), // seqstart_k + nullptr, // seqlen_kpads + total_q, + total_k, + b, + max_seqlen_q, + d, // hdim_q + d, // hdim_v + h, // nhead + h_k, // nhead_k + softmax_scale, // scale_s + 1, // scale_p + 1, // scale_o + stride_q, + stride_k, + stride_v, + stride_alibi_slopes, + stride_randval, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + 0, // nhead_stride_bias, FA without bias + nhead_stride_randval, + nhead_stride_lse, + nhead_stride_o, + batch_stride_q, + batch_stride_k, + batch_stride_v, + 0, // batch_stride_bias, FA without bias + batch_stride_randval, + batch_stride_lse, + batch_stride_o, + mask.left, + mask.right, + static_cast(mask.type), + p_dropout, + has_dropout_randval, + drop_seed_offset}; +} + +std::tuple +mha_varlen_fwd_ck(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. + const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. + c10::optional &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &cu_seqlens_q, // b+1 + const at::Tensor &cu_seqlens_k, // b+1 + c10::optional & /*seqused_k*/, + c10::optional &alibi_slopes_, // num_heads or b x num_heads + int max_seqlen_q, + const int max_seqlen_k, + const float p_dropout, + const float softmax_scale, + const bool zero_tensors, + bool is_causal, + int window_size_left, + int window_size_right, + const bool return_dropout_randval, + c10::optional gen_) +{ + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16, + "FlashAttention only support fp16 and bf16 data type"); + + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + TORCH_CHECK(cu_seqlens_q.dtype() == at::kInt, "cu_seqlens_q must have dtype int32"); + TORCH_CHECK(cu_seqlens_k.dtype() == at::kInt, "cu_seqlens_k must have dtype int32"); + + std::string q_dtype_str = q_dtype == at::kHalf ? 
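/* The CK traits identify the element type by a string tag, so the ATen dtype (already restricted to fp16/bf16 by the check above) is mapped to the matching tag here. */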
"fp16" : "bf16"; + + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + CHECK_DEVICE(cu_seqlens_q); + CHECK_DEVICE(cu_seqlens_k); + + // TODO - Support paged_KV + // const bool paged_KV = block_table_.has_value(); + // TORCH_CHECK(!paged_KV, "CK does not support paged_KV yet"); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + CHECK_CONTIGUOUS(cu_seqlens_q); + CHECK_CONTIGUOUS(cu_seqlens_k); + + const auto sizes = q.sizes(); + + const int batch_size = cu_seqlens_q.numel() - 1; + int num_heads = sizes[1]; + const int head_size_og = sizes[2]; + const int num_heads_k = k.size(1); + + const int max_num_blocks_per_seq = 0; + const int num_blocks = 0; + + if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } // causal=true is the same as causal=false in this case + + // TODO + // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case + // H/t Daniel Haziza + + const int total_q = q.size(0); + const int total_k = k.size(0); + + TORCH_CHECK(batch_size > 0, "batch size must be postive"); + TORCH_CHECK(head_size_og <= 256, "CK only supports head dimension at most 256"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + if (window_size_left >= max_seqlen_k) { window_size_left = -1; } + if (window_size_right >= max_seqlen_k) { window_size_right = -1; } + + mask_info mask; + + if (is_causal) { + // Causal is the special case where window_size_right == 0 and window_size_left < 0. + window_size_right = 0; + std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + "0"; + mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // casual + } + else if (window_size_left == -1 && window_size_right == -1) { + mask = mask_info::decode("0", max_seqlen_q, max_seqlen_k); // no mask + } + else { + // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. 
+ std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + std::to_string(window_size_right); + mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // local + } + + CHECK_SHAPE(q, total_q, num_heads, head_size_og); + CHECK_SHAPE(k, total_k, num_heads_k, head_size_og); + CHECK_SHAPE(v, total_k, num_heads_k, head_size_og); + CHECK_SHAPE(cu_seqlens_q, batch_size + 1); + CHECK_SHAPE(cu_seqlens_k, batch_size + 1); + + at::Tensor q_padded, k_padded, v_padded; + if (head_size_og % 8 != 0) { + q_padded = at::pad(q, {0, 8 - head_size_og % 8}); + k_padded = at::pad(k, {0, 8 - head_size_og % 8}); + v_padded = at::pad(v, {0, 8 - head_size_og % 8}); + } + else { + q_padded = q; + k_padded = k; + v_padded = v; + } + + at::Tensor out; + if (out_.has_value()) { + out = out_.value(); + TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); + CHECK_DEVICE(out); + TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); + CHECK_SHAPE(out, total_q, num_heads, head_size_og); + + if (head_size_og % 8 != 0) { out = at::empty_like(q_padded); } + } + else { + out = at::empty_like(q_padded); + } + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_8x = round_multiple(head_size_og, 8); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()}; + + auto opts = q.options(); + bool has_lse = true; + bool has_dropout = p_dropout > 0.0f; + + at::Tensor softmax_lse; + // TODO - check gradient, only training require lse + softmax_lse = at::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); + + at::Tensor p; + if (return_dropout_randval) { + TORCH_CHECK(has_dropout, "return_dropout_randval require p_dropout > 0"); + p = at::empty({num_heads, total_q, max_seqlen_k}, opts.dtype(at::kByte)); + } + + if (zero_tensors) + { + out.zero_(); + softmax_lse.fill_(-std::numeric_limits::infinity()); + if (return_dropout_randval) {p.zero_();} + } + + int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size(); + auto rng_state = at::empty({2}, opts.dtype(at::kLong)); + auto rng_state_ptr = reinterpret_cast(rng_state.data_ptr()); + + if (p_dropout > 0.0) { + auto gen = at::get_generator_or_default( + gen_, at::cuda::detail::getDefaultCUDAGenerator()); + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + auto philox_args = gen->philox_cuda_state(counter_offset); + hipLaunchKernelGGL( + flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, at::hip::getCurrentHIPStreamMasqueradingAsCUDA(), philox_args, rng_state_ptr); + } + + + if (max_seqlen_k > 0) { + auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1); + auto stream = at::cuda::getCurrentHIPStream().stream(); + ck_tile::stream_config stream_config{stream}; + + auto traits = + get_ck_fmha_varlen_fwd_traits(mask, q_dtype_str, head_size_8x, has_dropout, has_lse, alibi_slopes_.has_value()); + + auto args = + get_ck_fmha_varlen_fwd_args( + has_lse, + return_dropout_randval, + mask, + batch_size, + max_seqlen_q, + num_heads, + num_heads_k, + head_size_8x, + q_padded, + k_padded, + v_padded, + cu_seqlens_q, + cu_seqlens_k, + alibi_slopes_, + out, + softmax_lse, + p, + softmax_scale, + p_dropout, + drop_seed_offset); +#if (defined(__gfx90a__) || defined(__gfx942__)) + float t = fmha_fwd(traits, args, stream_config); + 
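/* The launch is compiled only for gfx90a (MI200-class) and gfx942 (MI300-class) targets via the #if guard above; fmha_fwd returns a negative value when it cannot service the requested traits/arguments (e.g. no matching CK kernel instance), which the check below turns into a hard error. */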
TORCH_CHECK(t >= 0, "invalid argument for fmha_fwd"); +#endif + } + else { + // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. + out.zero_(); + softmax_lse.fill_(std::numeric_limits::infinity()); + } + + at::Tensor out_padded = out; + if (head_size_og % 8 != 0) { + out = out.index({"...", at::indexing::Slice(at::indexing::None, head_size_og)}); + if (out_.has_value()) { out_.value().copy_(out); } + } + + //return kludge -- TODO:: REMOVE + at::Tensor seed_t = at::empty({}, at::dtype(at::kLong)); + at::Tensor offset_t = at::empty({}, at::dtype(at::kLong)); + + return {out, q_padded, k_padded, v_padded, softmax_lse, seed_t, offset_t, p}; +} +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/rename_ck_autogen_files.output.txt b/aten/src/ATen/native/transformers/hip/flash_attn/ck/rename_ck_autogen_files.output.txt new file mode 100644 index 000000000000..78f844fd2a1e --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/rename_ck_autogen_files.output.txt @@ -0,0 +1,1810 @@ +fmha_bwd_api.hip -> fmha_ck_autogen_5919133d2ed892745013b2fc5d503414cf0a4d83.hip +fmha_bwd_convert_dq_d128_bf16_b64x128_batch_o2.hip -> fmha_ck_autogen_e11a3b7d4fdfed64e64f7a95dbc64eff541092d6.hip +fmha_bwd_convert_dq_d128_bf16_b64x128_batch_o2_deterministic.hip -> fmha_ck_autogen_01cb354dddef6e99e4ac843f2adafcddfc58d520.hip +fmha_bwd_convert_dq_d128_bf16_b64x128_batch_o2_pd.hip -> fmha_ck_autogen_1b3e7c8969027d3316875f33dc50fe022e05ce37.hip +fmha_bwd_convert_dq_d128_bf16_b64x128_batch_o2_pd_deterministic.hip -> fmha_ck_autogen_38273a2f8e6bbb42ba0b0871b6c95abb34531f33.hip +fmha_bwd_convert_dq_d128_bf16_b64x128_batch_o2_ps.hip -> fmha_ck_autogen_2d43460c011b8d5e01ea98c9b8ddce962de59a96.hip +fmha_bwd_convert_dq_d128_bf16_b64x128_batch_o2_ps_deterministic.hip -> fmha_ck_autogen_4c0c50a1fac82d47dff2357ee3ddbfa0b2c8d487.hip +fmha_bwd_convert_dq_d128_bf16_b64x128_batch_o2_psd.hip -> fmha_ck_autogen_2a3a980a26682d879c3a3425f3ba5be3f5761adf.hip +fmha_bwd_convert_dq_d128_bf16_b64x128_batch_o2_psd_deterministic.hip -> fmha_ck_autogen_008f2429c678d13386a06e8d8b15c4b480940ff3.hip +fmha_bwd_convert_dq_d128_bf16_b64x128_group_o2_ps.hip -> fmha_ck_autogen_811db756577b61cde9fe8279d956980db9ee21a4.hip +fmha_bwd_convert_dq_d128_bf16_b64x128_group_o2_ps_deterministic.hip -> fmha_ck_autogen_492fbc418e829f89bcb8d93f8afd2869dd8dfccc.hip +fmha_bwd_convert_dq_d128_bf16_b64x128_group_o2_psd.hip -> fmha_ck_autogen_75f2010bf6c478d2f0eba77e912697661306c1cb.hip +fmha_bwd_convert_dq_d128_bf16_b64x128_group_o2_psd_deterministic.hip -> fmha_ck_autogen_0153ec18d3ded0f8bdc6459ea5757ebd94d9faf2.hip +fmha_bwd_convert_dq_d128_fp16_b64x128_batch_o2.hip -> fmha_ck_autogen_3eb2ea922daabbba131b90713e06d8caf5f30662.hip +fmha_bwd_convert_dq_d128_fp16_b64x128_batch_o2_deterministic.hip -> fmha_ck_autogen_c0f76aff077c28f8afd7b22f284cf2894e08a043.hip +fmha_bwd_convert_dq_d128_fp16_b64x128_batch_o2_pd.hip -> fmha_ck_autogen_f48f8b681a405bfeba5aadaef40f32367ec5cd2b.hip +fmha_bwd_convert_dq_d128_fp16_b64x128_batch_o2_pd_deterministic.hip -> fmha_ck_autogen_4cabdafad0bf803223ba5e8f474cd59233dc48cb.hip +fmha_bwd_convert_dq_d128_fp16_b64x128_batch_o2_ps.hip -> fmha_ck_autogen_0801c56831b4c6428200db6318638a2129bb197a.hip +fmha_bwd_convert_dq_d128_fp16_b64x128_batch_o2_ps_deterministic.hip -> fmha_ck_autogen_91b9e2616c2fe0480096b1ccf0f74d584b220146.hip +fmha_bwd_convert_dq_d128_fp16_b64x128_batch_o2_psd.hip -> fmha_ck_autogen_4f1e1c969b57659e7e1367ac9ba10ed5ef5b69a9.hip 
+fmha_bwd_convert_dq_d128_fp16_b64x128_batch_o2_psd_deterministic.hip -> fmha_ck_autogen_ecd7dec90b3c62bf3a30bd75d3c6869529a06b01.hip +fmha_bwd_convert_dq_d128_fp16_b64x128_group_o2_ps.hip -> fmha_ck_autogen_88ea5b5346c87cc4fc1e841c518080df4ab811a2.hip +fmha_bwd_convert_dq_d128_fp16_b64x128_group_o2_ps_deterministic.hip -> fmha_ck_autogen_4395d3c96b3f4556b9765fd0a3b5701b2fb10948.hip +fmha_bwd_convert_dq_d128_fp16_b64x128_group_o2_psd.hip -> fmha_ck_autogen_b8fbc6f6e9c515edce3c7a438b3bc308b30d3857.hip +fmha_bwd_convert_dq_d128_fp16_b64x128_group_o2_psd_deterministic.hip -> fmha_ck_autogen_490a68220a7b621ae9817d7b77f55de239b0a4f3.hip +fmha_bwd_convert_dq_d256_bf16_b64x64_batch_o2.hip -> fmha_ck_autogen_344932e2655d7b32704be8de9a63bbd8c3369f02.hip +fmha_bwd_convert_dq_d256_bf16_b64x64_batch_o2_deterministic.hip -> fmha_ck_autogen_5a85ae0a16e4b293b549bcb6a3ee52df7fccca32.hip +fmha_bwd_convert_dq_d256_bf16_b64x64_batch_o2_pd.hip -> fmha_ck_autogen_963986150adcd6e1d3886bacf2166de1252e14df.hip +fmha_bwd_convert_dq_d256_bf16_b64x64_batch_o2_pd_deterministic.hip -> fmha_ck_autogen_8bd1a40b12ce927323594fcce61eb9c20cc5e3d4.hip +fmha_bwd_convert_dq_d256_bf16_b64x64_batch_o2_ps.hip -> fmha_ck_autogen_296c5836ba118969c4ba89ed62a98dffe3105738.hip +fmha_bwd_convert_dq_d256_bf16_b64x64_batch_o2_ps_deterministic.hip -> fmha_ck_autogen_6cfb7075345704340ff33dc0ef7c04ef127f26ad.hip +fmha_bwd_convert_dq_d256_bf16_b64x64_batch_o2_psd.hip -> fmha_ck_autogen_22511de2592b6e350737e44865e1fed6496e3f32.hip +fmha_bwd_convert_dq_d256_bf16_b64x64_batch_o2_psd_deterministic.hip -> fmha_ck_autogen_609f68180582384ba81aae2b1d4a4c52dde2c68c.hip +fmha_bwd_convert_dq_d256_bf16_b64x64_group_o2_ps.hip -> fmha_ck_autogen_c9fe51f982abd60e567d4238d3266fb60e45814b.hip +fmha_bwd_convert_dq_d256_bf16_b64x64_group_o2_ps_deterministic.hip -> fmha_ck_autogen_10a055e5c3d6a953d470db5dc21449766248058a.hip +fmha_bwd_convert_dq_d256_bf16_b64x64_group_o2_psd.hip -> fmha_ck_autogen_327e27892bc57f3dec0da24f94f2a483d6c9321b.hip +fmha_bwd_convert_dq_d256_bf16_b64x64_group_o2_psd_deterministic.hip -> fmha_ck_autogen_c581974c8b6f43f60d0af29c350d850b55c03121.hip +fmha_bwd_convert_dq_d256_fp16_b64x64_batch_o2.hip -> fmha_ck_autogen_01ac1a2ecf9a487809e46faa92e267df2d47de91.hip +fmha_bwd_convert_dq_d256_fp16_b64x64_batch_o2_deterministic.hip -> fmha_ck_autogen_dbc4135fce01e8731fec7a78d0cc0fdeeae28b90.hip +fmha_bwd_convert_dq_d256_fp16_b64x64_batch_o2_pd.hip -> fmha_ck_autogen_e09d9baa269dfbb30b714389d1733be51cc419b7.hip +fmha_bwd_convert_dq_d256_fp16_b64x64_batch_o2_pd_deterministic.hip -> fmha_ck_autogen_5f71e663978dbcba859c5114ec675a712e343fd6.hip +fmha_bwd_convert_dq_d256_fp16_b64x64_batch_o2_ps.hip -> fmha_ck_autogen_d257148f457557ea80ca56690e525db3a4b0ff55.hip +fmha_bwd_convert_dq_d256_fp16_b64x64_batch_o2_ps_deterministic.hip -> fmha_ck_autogen_8e2c587db8bd9f1b551624e0cf8b67a90245d7da.hip +fmha_bwd_convert_dq_d256_fp16_b64x64_batch_o2_psd.hip -> fmha_ck_autogen_8c13c4f3f645a2bb475eb1c55ce1de452f0e2332.hip +fmha_bwd_convert_dq_d256_fp16_b64x64_batch_o2_psd_deterministic.hip -> fmha_ck_autogen_7b7fa76609243a8709f349ffc0d9d88157f28dc9.hip +fmha_bwd_convert_dq_d256_fp16_b64x64_group_o2_ps.hip -> fmha_ck_autogen_2b3326e055da32cc979892a2fbd0f7b003cb9f98.hip +fmha_bwd_convert_dq_d256_fp16_b64x64_group_o2_ps_deterministic.hip -> fmha_ck_autogen_671828f15eec2a58be23063a1a8132d337cd26de.hip +fmha_bwd_convert_dq_d256_fp16_b64x64_group_o2_psd.hip -> fmha_ck_autogen_457eaffbff3c58183a656687010daa2c16cfc26e.hip 
+fmha_bwd_convert_dq_d256_fp16_b64x64_group_o2_psd_deterministic.hip -> fmha_ck_autogen_d18727988e47264b42b4153dc82fc1a750f08db0.hip +fmha_bwd_convert_dq_d32_bf16_b64x128_batch_o2.hip -> fmha_ck_autogen_ab6cd5c9242f8278c8f3d9ce57b97d605c7e5a3e.hip +fmha_bwd_convert_dq_d32_bf16_b64x128_batch_o2_deterministic.hip -> fmha_ck_autogen_0c93c65e5942a2f43f2e491547add02777dd2eee.hip +fmha_bwd_convert_dq_d32_bf16_b64x128_batch_o2_pd.hip -> fmha_ck_autogen_d32c64ef01aa228277d031a74df51363f98aa2b0.hip +fmha_bwd_convert_dq_d32_bf16_b64x128_batch_o2_pd_deterministic.hip -> fmha_ck_autogen_e5c5079636a4a31a849ce8a5af89d50330a74628.hip +fmha_bwd_convert_dq_d32_bf16_b64x128_batch_o2_ps.hip -> fmha_ck_autogen_ea62567e9ea16771d8445464c38f5a2931cb355a.hip +fmha_bwd_convert_dq_d32_bf16_b64x128_batch_o2_ps_deterministic.hip -> fmha_ck_autogen_c6e2da8b791d31f4ba05ef5f833fd6dea9e35f1c.hip +fmha_bwd_convert_dq_d32_bf16_b64x128_batch_o2_psd.hip -> fmha_ck_autogen_f731289837f915e2aec1bd01eef1b3c1b099864d.hip +fmha_bwd_convert_dq_d32_bf16_b64x128_batch_o2_psd_deterministic.hip -> fmha_ck_autogen_233132e712eba8972ba444c604f89e01c5b84cc0.hip +fmha_bwd_convert_dq_d32_bf16_b64x128_group_o2_ps.hip -> fmha_ck_autogen_afc4b47a6fa62a4ca5cff6a7e01c9f6b371d2215.hip +fmha_bwd_convert_dq_d32_bf16_b64x128_group_o2_ps_deterministic.hip -> fmha_ck_autogen_bec30e7107c5dce3fe6aa87d83ed96da75478da0.hip +fmha_bwd_convert_dq_d32_bf16_b64x128_group_o2_psd.hip -> fmha_ck_autogen_f4658c32d562f9d60c5ca1262a2e0df2375063bb.hip +fmha_bwd_convert_dq_d32_bf16_b64x128_group_o2_psd_deterministic.hip -> fmha_ck_autogen_9545f95c1093c60f0fb6c794636f79aaeb53b733.hip +fmha_bwd_convert_dq_d32_fp16_b64x128_batch_o2.hip -> fmha_ck_autogen_e6b53fb8d81148ff384d31a703bb4c2e7a5a33af.hip +fmha_bwd_convert_dq_d32_fp16_b64x128_batch_o2_deterministic.hip -> fmha_ck_autogen_7aa14aa94d625b33df1adfa30ef4d91769592608.hip +fmha_bwd_convert_dq_d32_fp16_b64x128_batch_o2_pd.hip -> fmha_ck_autogen_b5db3d5b1d8af89381fc4b8073f84c5fa25fdef5.hip +fmha_bwd_convert_dq_d32_fp16_b64x128_batch_o2_pd_deterministic.hip -> fmha_ck_autogen_e8a9427f34bbf5ddb28a39161acc36806e68f2d0.hip +fmha_bwd_convert_dq_d32_fp16_b64x128_batch_o2_ps.hip -> fmha_ck_autogen_724d1d4408196d611b2e0535bf8833652acbd6ef.hip +fmha_bwd_convert_dq_d32_fp16_b64x128_batch_o2_ps_deterministic.hip -> fmha_ck_autogen_a3ac4f93722dc314086f1b7d7b8adc687cd75f82.hip +fmha_bwd_convert_dq_d32_fp16_b64x128_batch_o2_psd.hip -> fmha_ck_autogen_377b70f54cb2778b5ce3df936b477f775eea8b3c.hip +fmha_bwd_convert_dq_d32_fp16_b64x128_batch_o2_psd_deterministic.hip -> fmha_ck_autogen_5f20263fd84776f155519b3481be5e2c5b035585.hip +fmha_bwd_convert_dq_d32_fp16_b64x128_group_o2_ps.hip -> fmha_ck_autogen_9745b04a8026a01828c5dd606d89d044d3ed1d99.hip +fmha_bwd_convert_dq_d32_fp16_b64x128_group_o2_ps_deterministic.hip -> fmha_ck_autogen_a7784b03ad757d51c234fa86ea9891f055ecd5c1.hip +fmha_bwd_convert_dq_d32_fp16_b64x128_group_o2_psd.hip -> fmha_ck_autogen_22105635385fbfb5d2f330df83ba6747bcb27f6d.hip +fmha_bwd_convert_dq_d32_fp16_b64x128_group_o2_psd_deterministic.hip -> fmha_ck_autogen_3afbb5ac9048a962a60f48886728220ae6c2aeaf.hip +fmha_bwd_convert_dq_d64_bf16_b64x128_batch_o2.hip -> fmha_ck_autogen_429b82a27571ac91e3631cbdb7e0a58155abf962.hip +fmha_bwd_convert_dq_d64_bf16_b64x128_batch_o2_deterministic.hip -> fmha_ck_autogen_dc818f3ce244743cb1dbff9aca399df90742a6d0.hip +fmha_bwd_convert_dq_d64_bf16_b64x128_batch_o2_pd.hip -> fmha_ck_autogen_7f9403cb91d6aabebf081afae94a8ba397d8d24f.hip 
+fmha_bwd_convert_dq_d64_bf16_b64x128_batch_o2_pd_deterministic.hip -> fmha_ck_autogen_ca5681d4e5871aacef74bdba9e368445875252d3.hip +fmha_bwd_convert_dq_d64_bf16_b64x128_batch_o2_ps.hip -> fmha_ck_autogen_1e7d7888480b83c78833214b32e10f37a6e20301.hip +fmha_bwd_convert_dq_d64_bf16_b64x128_batch_o2_ps_deterministic.hip -> fmha_ck_autogen_4018f690b6322588041bb467beabd8a7bc79a2e0.hip +fmha_bwd_convert_dq_d64_bf16_b64x128_batch_o2_psd.hip -> fmha_ck_autogen_23047ea90076e3b0a3eb0586d49b9ee74ca6d279.hip +fmha_bwd_convert_dq_d64_bf16_b64x128_batch_o2_psd_deterministic.hip -> fmha_ck_autogen_5a216f777feec4752f5882677b18168225da4b53.hip +fmha_bwd_convert_dq_d64_bf16_b64x128_group_o2_ps.hip -> fmha_ck_autogen_fd19d7614f2ed5da21a52ed172ef62cc07c9c01a.hip +fmha_bwd_convert_dq_d64_bf16_b64x128_group_o2_ps_deterministic.hip -> fmha_ck_autogen_9893336a4b00b2a63f23ed7e13ec54c82d9e5063.hip +fmha_bwd_convert_dq_d64_bf16_b64x128_group_o2_psd.hip -> fmha_ck_autogen_131c1fdc4206bb952b2fea675f24e3b09f605eef.hip +fmha_bwd_convert_dq_d64_bf16_b64x128_group_o2_psd_deterministic.hip -> fmha_ck_autogen_cc4ac5a18f57f2ebb65f7e356e858ab0d59b2133.hip +fmha_bwd_convert_dq_d64_fp16_b64x128_batch_o2.hip -> fmha_ck_autogen_dde93ffe7fca311e136e42fbcd12b05c9fc7174c.hip +fmha_bwd_convert_dq_d64_fp16_b64x128_batch_o2_deterministic.hip -> fmha_ck_autogen_7b67045d438a7e4b8f3a313a5df5a85f351c1be5.hip +fmha_bwd_convert_dq_d64_fp16_b64x128_batch_o2_pd.hip -> fmha_ck_autogen_9689ecd7bf51bcffe9f5002959bdda41c50a3c8b.hip +fmha_bwd_convert_dq_d64_fp16_b64x128_batch_o2_pd_deterministic.hip -> fmha_ck_autogen_c41b6eda4f250da059fe0c428428219ff5a250ef.hip +fmha_bwd_convert_dq_d64_fp16_b64x128_batch_o2_ps.hip -> fmha_ck_autogen_c45a5e40f6a66bc5292a56e0097c69fe37cedfb3.hip +fmha_bwd_convert_dq_d64_fp16_b64x128_batch_o2_ps_deterministic.hip -> fmha_ck_autogen_ffc6056d9fe125a4dbe08c1d86354e51f7daadd5.hip +fmha_bwd_convert_dq_d64_fp16_b64x128_batch_o2_psd.hip -> fmha_ck_autogen_2995d39cd62f20622a31f11a292ed175abb5fdf9.hip +fmha_bwd_convert_dq_d64_fp16_b64x128_batch_o2_psd_deterministic.hip -> fmha_ck_autogen_cb10303a0b79f2710eb7c66896d3c1f8b12c04dd.hip +fmha_bwd_convert_dq_d64_fp16_b64x128_group_o2_ps.hip -> fmha_ck_autogen_81dd3ea61bb61de02667b14f5a94198f48c7307b.hip +fmha_bwd_convert_dq_d64_fp16_b64x128_group_o2_ps_deterministic.hip -> fmha_ck_autogen_d3af8763f289dace1054bdcb4dfeda28b0aefcae.hip +fmha_bwd_convert_dq_d64_fp16_b64x128_group_o2_psd.hip -> fmha_ck_autogen_e6e6b10e73733716e71ebf5a53703fb935fc5e02.hip +fmha_bwd_convert_dq_d64_fp16_b64x128_group_o2_psd_deterministic.hip -> fmha_ck_autogen_e75c757c67aa23cb88e1aced6fcf36b7b28391db.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp.hip -> fmha_ck_autogen_2b3af90387f1d227119c5dcd4b71362940bbce52.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi.hip -> fmha_ck_autogen_e3015c5d50481547aa5754d042d9d7040cf1c7ff.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_deterministic.hip -> fmha_ck_autogen_a4700d87a19a173e84d64e43cffabbed52366e35.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_dropout_wg16.hip -> fmha_ck_autogen_6af4c15a119e805e4407b184625f57966f8833d9.hip 
+fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_8b17c082f249649eca733a8f0cdf9a1205c3e3d7.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask.hip -> fmha_ck_autogen_226662cf1c9900a4334d2cadcc5f5ac3ad355f05.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask_deterministic.hip -> fmha_ck_autogen_d723b191785c97d284675f700a7baeb52a2eb791.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_afdab954fd111ec48721f25710d61c0c8affd8db.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_d54ac01458df3f240e0656d82330f9de23ba9651.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_deterministic.hip -> fmha_ck_autogen_4ed6da5357b67cc28aee4afa9523adaf055c4e32.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_dropout_wg16.hip -> fmha_ck_autogen_8c3bd4e029bba76ebfc79e6522dbc8ca0bba5dd2.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_dropout_wg16_deterministic.hip -> fmha_ck_autogen_dbde2ef18e2174ebe13a6e7c8c2a6b05a6612047.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask.hip -> fmha_ck_autogen_c363ee1b087f6b504a3dd3972b96e77db02b0582.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask_deterministic.hip -> fmha_ck_autogen_a02a71fdd587e47ee68e0cc76c3c4494ce06c359.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask_dropout_wg16.hip -> fmha_ck_autogen_968fc75a7d102aca068e3ceb6111728c280fa837.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_4a06b5b153ea6e8b1e20d9aad9d4633333fd98f5.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps.hip -> fmha_ck_autogen_bde24a8dbe6add6f2dd2beb48b1280f3a84a9b2a.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi.hip -> fmha_ck_autogen_415b183c50dd2663dabe3eb8b780913b778c54ab.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_deterministic.hip -> fmha_ck_autogen_258d747083272ea657604ac84867ecea17bd65da.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_dropout_wg16.hip -> fmha_ck_autogen_2a97c457144cb63a9c6c3d6be613b47bd0df9928.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_8d7549e66ef309e32779ddc2a1f14e79bae53754.hip 
+fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask.hip -> fmha_ck_autogen_5cd41b6f578f3c903eb9d58ebfab62eb296044e0.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask_deterministic.hip -> fmha_ck_autogen_dc34b6ef496d4e0d8fbbe10731d4a7b1c136c036.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_4a9f3da698a6103caf25d785928dd9f814ac27b4.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_aa996b9c843200a2ec33ed4319b48106cd7c6384.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_deterministic.hip -> fmha_ck_autogen_1d02609fb803ea2697e2c2cef35e6f923d2578cf.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_dropout_wg16.hip -> fmha_ck_autogen_58eb2edc7738d8d18ac359691da261ceaaf71788.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_dropout_wg16_deterministic.hip -> fmha_ck_autogen_c35ea54eb6cd0f3756c462c66d9be956279b46ad.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask.hip -> fmha_ck_autogen_0f0c699d9c3b0ed62097e38ba05e40e815cf474e.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask_deterministic.hip -> fmha_ck_autogen_64fe2db75cb20428856b02cd1cc8d7b393a6ad9c.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask_dropout_wg16.hip -> fmha_ck_autogen_38b94d76503e13c911781169fbc378517332c42e.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_e0966fa1ff013e477b1706928de6cb7f8587c154.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk.hip -> fmha_ck_autogen_b9559dd36a0a4f5e068a722e285f485137bd5ef0.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi.hip -> fmha_ck_autogen_5a05b4e7782bd0e29ca9f6d33fc59d4304136d41.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_deterministic.hip -> fmha_ck_autogen_b9385db12001110c42eff6aabad935a69ad3afe2.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_dropout_wg16.hip -> fmha_ck_autogen_c1f721a330b2d0fac13b22061616d7b10c0f91e9.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_37fe04467e87ec2110f60c7aea0cc9bf2ca07481.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask.hip -> fmha_ck_autogen_d4b99af9a573df50a27fccbec3fa8e350f1854eb.hip 
+fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask_deterministic.hip -> fmha_ck_autogen_20588bcac681a5d69f252d7523a3681a0c6b6181.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_a3709e4fc53d2254a03ea7660b8c72d2f47cf1ad.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_47fe73f04cef91cd2a0682e905483968ff80eadb.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_deterministic.hip -> fmha_ck_autogen_ad9b99a194b59d3149842c15733394da275b12c0.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_dropout_wg16.hip -> fmha_ck_autogen_062c8c3c1cf6c33af4574099e9b6ac54a55ad776.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_dropout_wg16_deterministic.hip -> fmha_ck_autogen_ab1ca4ce061f7f69a250356f613cab00d1e2ac71.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask.hip -> fmha_ck_autogen_cd4efcdd12184211c74e7b3f2f30fecf1041ca32.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask_deterministic.hip -> fmha_ck_autogen_1d0b822743e0205f60521d38d7c64f589fdf0f58.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask_dropout_wg16.hip -> fmha_ck_autogen_e10f47a44400de385ddbeb99475b717c5646fb41.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_3e562e6c3af28b8478020ce3c3bf73c036001c93.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk.hip -> fmha_ck_autogen_1a99b2625adffa8215276bb88fc65bae944b846b.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi.hip -> fmha_ck_autogen_56cc4399c5567a9495f17d54c712cc9e65e57521.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_deterministic.hip -> fmha_ck_autogen_ba8b09f0aaa40a7c9ad5f0458b460d3e328f3c74.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16.hip -> fmha_ck_autogen_43e7c78e8f65be35e2753a0ad5123118555c56b2.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_bed5a8c5cf683f6dfaefad72c2e2f5c2f2b2732f.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask.hip -> fmha_ck_autogen_73ec21ed6e040260c4f04ef68ef9307aa86985a7.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_deterministic.hip -> 
fmha_ck_autogen_3642b78913a853a62dbff8b99d9ae3fa458f461d.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_ac5e9aee85cd16903bf7b82a4ac10402b0b26e22.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_5f954a393b7b5a7131c13d0c4578443f468a738d.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_deterministic.hip -> fmha_ck_autogen_78e945db4afa1330fe3978bc1bc9ae99828ae287.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16.hip -> fmha_ck_autogen_d4aff499ad527be5fe33b8e92547df57af26d40d.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16_deterministic.hip -> fmha_ck_autogen_2b8169ce4b4b9a17ac96fbb232e6a93f22071ab4.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask.hip -> fmha_ck_autogen_0a89417a043556970f72eebd48b4f3e7ac15377a.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_deterministic.hip -> fmha_ck_autogen_4824e1f8cda50f80988857611da766685da94494.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16.hip -> fmha_ck_autogen_dbae1670fac6812b2d2cbad973e4b475509ea504.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_5daedab8931f2eefb649b91e80145cb71b63360c.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv.hip -> fmha_ck_autogen_dfcd68acfca68d1acac94f493e25be0ef20f209f.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi.hip -> fmha_ck_autogen_3511c54e6a6f9eec378d8b661121066536195d3a.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_deterministic.hip -> fmha_ck_autogen_deb9ec2cccab94920e40f62a1f0f094acd919d07.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_0fbb0bef3b388867e75d7a8a187b8b4b650a42ae.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_ab0be5a2072b5e87f5ee58149688796b6513219f.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask.hip -> fmha_ck_autogen_91a6200e36944b1f11106c02f7fcee053f01ee71.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_1f81f8cce0d77dec9f977b9eeb0778b70a13fa75.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask_dropout_wg16.hip -> 
fmha_ck_autogen_bcd7ccdceb7baf3b986f2a0248827822a5f72e47.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_58762476c7f2bb05dce92ec22c0acbeb03676746.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_deterministic.hip -> fmha_ck_autogen_f4df1cbfbaf67705820f125b474469ad7ebab0c0.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_dropout_wg16.hip -> fmha_ck_autogen_f42cf0e5fe479690883507028748b0cd3dc83cbb.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_f682399cd6412fed6a1141296a7e4d42078f7b29.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask.hip -> fmha_ck_autogen_256ef175029a43e64164176d4eb212baf9d27bb9.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask_deterministic.hip -> fmha_ck_autogen_3206cc121ce8955ed59ea3b12b858ee2e0cf82f8.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask_dropout_wg16.hip -> fmha_ck_autogen_d1840494c4fa78ff399c0399b3ad7ca3d22d4587.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_31c4b866692ba5c3d115482bef4790733863c1fc.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv.hip -> fmha_ck_autogen_b5c7fca1f76a31b0390e92d90d569fab94d4f783.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi.hip -> fmha_ck_autogen_dc3d625c5ad3e871f5a727ac946df642d988b9ab.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_deterministic.hip -> fmha_ck_autogen_ca4c6ad28aff1976c6dd36974ec3b339aa3090e9.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_eac353f963c52624cf79e82cc2b2c02eed94b677.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_29bffc159b0bb826ba489ae763dae141bfe8e802.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask.hip -> fmha_ck_autogen_9b327f0fa1155f2235d76be45cd22e3db5a69429.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_d0dd0165ee91c095a19ceddf08789e3576912590.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_5344427df3ae9392c4fc4c25c232196828e70648.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask_dropout_wg16_deterministic.hip -> 
fmha_ck_autogen_3f7315955f555768f24585a50d75e216c40f062d.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_deterministic.hip -> fmha_ck_autogen_dbcea8f7b5930abf76eecefce92d0db785d2df5d.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_dropout_wg16.hip -> fmha_ck_autogen_165dfb45658df8f1ae8dc0738ac9614740f2576c.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_8a58d4bca33c4c0e79141a56688049237d170d1b.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask.hip -> fmha_ck_autogen_fe9d98dbec5096a89b116f85675af772f023014a.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask_deterministic.hip -> fmha_ck_autogen_d1c0dfd19a08d61586758091370acbdc6f267017.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask_dropout_wg16.hip -> fmha_ck_autogen_960ecb3013071fb65f2d5ed4c947c4bf303e5308.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_1552dc38d26f6badb7a9bcb5ce9124d54cc45ed3.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv.hip -> fmha_ck_autogen_3af86f458fb4dfcceb7db3357fbae0dc15142a15.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi.hip -> fmha_ck_autogen_74ba59d347ce8916a22b40e6f22a3c89e13db4d0.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_deterministic.hip -> fmha_ck_autogen_7344f96bed2f56793b1c2583485aa161cdf30379.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_ad989d2ce769f20e175fa88f4082c1c25fe03062.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_096e888c52d0f4a5847d7515fcc66208b1ff40d3.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask.hip -> fmha_ck_autogen_7cbe4562c51d6829ec5942e11035c452fe318b3a.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_621da34ee666903307d3a09b7a032f2a70054759.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_c64f4cdce32189065362a502105c31bd2d9d99a4.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_987f00dd759d9714693e7517dfaa8bb427294d42.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_deterministic.hip -> 
fmha_ck_autogen_1c2a2d78176e3f0a78e3ad78217e75a4430c0de5.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_dropout_wg16.hip -> fmha_ck_autogen_ba145535e53899fe127987aa854f81234a9c51c4.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_0968cebd81ade762c2f92fffc0153fa7a2b91eb5.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask.hip -> fmha_ck_autogen_b41735d250b5a16967281a5f07873b9cde3df4d6.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask_deterministic.hip -> fmha_ck_autogen_fac5a0f98b94530befd634891e42c424bb86f0e1.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask_dropout_wg16.hip -> fmha_ck_autogen_ffb8adef0cef91a86f36872407fea35df90e8f2b.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_78e1edca5abe1bb3e7aa946eab6484b7bed806a3.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv.hip -> fmha_ck_autogen_88ed7f650c958a644c8031aeb88688b1e42458e5.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi.hip -> fmha_ck_autogen_ef2ebb4a86e7ed0001de9c5e607b66fe8877409f.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_deterministic.hip -> fmha_ck_autogen_f3ff73f82aee3184849d04c2364eaa45c6d0de9c.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_fb9477a613665cebcad781389ba7c5a36f51efe2.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_21f860d42fdc2cc6bd743d53ba546e332c22fedf.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_55ea83a47c6299fefa4220ed88f7a8e1dd938215.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_80987e2d765efc320eaee813607c94c80ee35aa4.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_288458c5a0720ef152848713119ebce6d76db6d6.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_d6149eea92f2c40c11de3b778102fcf9b6a006b8.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_deterministic.hip -> fmha_ck_autogen_7b5680f97836be4a369802e8115617a83875703e.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16.hip -> 
fmha_ck_autogen_4347e039c003489dd528faf5d710e687321a3fd7.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_e5b2bb9f8466de1ad5210e4c39ee7b8ecacdffa9.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask.hip -> fmha_ck_autogen_bc6ce17223d8d83a64b8c96ac88223e4441a4692.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_deterministic.hip -> fmha_ck_autogen_fc1790325b59bd44b0a5f6cf9723a25fd845cba7.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16.hip -> fmha_ck_autogen_662767e588220d0dc6137b00cc1d8dcc91e97134.hip +fmha_bwd_d128_bf16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_a3dc780b17152f696f9b957432c2eae8fb16e85e.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk.hip -> fmha_ck_autogen_9a8e04fe9432a60f86ff0369e8c1851821074a04.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi.hip -> fmha_ck_autogen_835a906031a258c6362313eec783678bd8125c91.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_deterministic.hip -> fmha_ck_autogen_ee8e709eec7aef1fa681053c6d2969a5ff18c45c.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16.hip -> fmha_ck_autogen_8d079c1eb36db8461fa8b861c56760afcd97cc34.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_64b3488ddf3bb1a4870371882f0a5d267bdfdf73.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask.hip -> fmha_ck_autogen_ca3975efd767ddf7c12e308d948bdcaf0968493a.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_deterministic.hip -> fmha_ck_autogen_82ad0c0580516485ea432d98f53e73f6dfec548c.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_4306c6c37cf472ad262f53941611b5e60072bdf6.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_4904c5910a2d0595b39a3f87652a9d1ef4fcbe80.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_deterministic.hip -> fmha_ck_autogen_f57f84892e2a8496169b7406e63b0d4f5aa63aaf.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16.hip -> fmha_ck_autogen_f24f26e45d5cf567d29fbe375fbf8abdec39186f.hip 
+fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16_deterministic.hip -> fmha_ck_autogen_a85d35b2fd98742427930eb536e346ffb005edd8.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask.hip -> fmha_ck_autogen_19df4e13108e043361e9528b71df56f04f696a0c.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_deterministic.hip -> fmha_ck_autogen_dbb06b43d5d65429e23cc717448cf1fffb0cfd74.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16.hip -> fmha_ck_autogen_0ef9b9413697d6f4573c6605bff6f58d027c5016.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_0b2efefea81036641561bed80c75d77651176f74.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv.hip -> fmha_ck_autogen_27c2000d32c230a57a6712f27bc0fba02722f5fd.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi.hip -> fmha_ck_autogen_ab1d7f93427095e39bfc1d986b3d7fe54073ec75.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_deterministic.hip -> fmha_ck_autogen_7dfe21ee27f8a0ca0407ef0dea73cd73ae6940db.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_8007bf7ae1b71bf8ac4a793aa519ad333aa7a7ba.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_3937d9dfb68351de2942e32f35e2ca1ce71edfa8.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_de1ff66d2aeb47d2fdccaa4bb6b9d066b380c99e.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_5403eec1cdd216d5c4a7ba977e2ef92a0d7fcc8b.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_358399e756ed5026baf3ab78af17489dc07b9532.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_bd064e302ff5b983dbdb4ccf51383fb29ddff44f.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_deterministic.hip -> fmha_ck_autogen_c11d68fe766fc753c657362673704005b538660b.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16.hip -> fmha_ck_autogen_fbea85b766bf0c918ee0baf24dffc6a5563d5105.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_92f9ad0fb65638cfffb3e7786f2cbf01d9585b23.hip 
+fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask.hip -> fmha_ck_autogen_0a55ed15ef58c941e06dda890aeb530e28eb7bba.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_deterministic.hip -> fmha_ck_autogen_df4bb75ca79f805a81fbad750ad22f6d22b0d8ff.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16.hip -> fmha_ck_autogen_7ab03a62e064864e1e9c1cd506c1b2e1786a777c.hip +fmha_bwd_d128_bf16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_a189292c81a18d21a2921ce6740f81ebf4c046ad.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp.hip -> fmha_ck_autogen_c9312d7159369d13f3148a6f0882dfad6921ceec.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi.hip -> fmha_ck_autogen_0cdef49859c80c6b3ba18eb2fb4c35c72abc1cf2.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_deterministic.hip -> fmha_ck_autogen_ae87b1d5c50606430b544ed650d87df24366e7d5.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_dropout_wg16.hip -> fmha_ck_autogen_0a92671b6ea99891c0d69b1c793f4d131b9a82ed.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_f4a6438394dd3427f29aa0bbe58ad1f797c3c38d.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask.hip -> fmha_ck_autogen_fa85f869a92f0482605e52019828244b12e12b44.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask_deterministic.hip -> fmha_ck_autogen_c2541b6b5cf27de3f45f60671d36602f07ce1783.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_0595316f0dfffda03e5296b959a49ec3f3c48d67.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_fff7aa57cca501f221077124359a589b3a6f9d0a.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_deterministic.hip -> fmha_ck_autogen_358d28c958c0a831a615a4811d13279b18db09c4.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_dropout_wg16.hip -> fmha_ck_autogen_96f1bb85dff8c97846f6b2e8796a6289bcd0d9d3.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_dropout_wg16_deterministic.hip -> fmha_ck_autogen_14d4630876785655bd4950566e81ae0b645c0d3c.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask.hip -> fmha_ck_autogen_a48843d844f78690c7a45b730652f0f763c595c7.hip 
+fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask_deterministic.hip -> fmha_ck_autogen_3e143d88eaa0d9cfea856b2f3a57d1275a656627.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask_dropout_wg16.hip -> fmha_ck_autogen_18ed7195a9443c84956c3f32839cb3ab9056bdfc.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_f7035f4bfd8f2f427720a07e3c311bccc1dba683.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps.hip -> fmha_ck_autogen_f87790f260630f312b84888dcbdf849ce130ae59.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi.hip -> fmha_ck_autogen_fe97b7adcd67ed9bda8831d1f3f1ca7590c6d251.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_deterministic.hip -> fmha_ck_autogen_b41a30092e8138877c1f6c25656e0f8ae2c2444e.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_dropout_wg16.hip -> fmha_ck_autogen_af06c0dae15684f83e15722a4c07342af9ea011c.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_158d5ce564c3ae1eefb54e3d41dde2604560ef4a.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask.hip -> fmha_ck_autogen_49f5017cc0f5c8c8dc71492e7765cf729c1f225c.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask_deterministic.hip -> fmha_ck_autogen_280bfced8745fbd9266207463fb41476dc23afff.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_eca613eaa8471ad7da66d2f8f2b8e07f6e02b467.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_8e1b48a28b71c7f4c78eb14321b39951a7c5e903.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_deterministic.hip -> fmha_ck_autogen_a1d6ad9de7ac7993ae1923a2ef070b7dacb8c563.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_dropout_wg16.hip -> fmha_ck_autogen_04641230fe9a50a221047f7a1df8a370f72805b9.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_dropout_wg16_deterministic.hip -> fmha_ck_autogen_bc1ae1dddb8cc5d78196da6b26ebe66c1ce7e567.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask.hip -> fmha_ck_autogen_e8d9b65558398c0c10127b560807578ef117d7ed.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask_deterministic.hip -> fmha_ck_autogen_87e3a06266deda093bdf28af82d8666066157fc6.hip 
+fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask_dropout_wg16.hip -> fmha_ck_autogen_0a672fca51de618e3441cf8764e8e83eb782f2c7.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_92d841e6d783bb46d841aafd9027f92dd1b61b88.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk.hip -> fmha_ck_autogen_01f74764c3c3284fdd1b67d0ea781c2261ed0de6.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi.hip -> fmha_ck_autogen_feb5e77111fe1e20bafdb83a925b5faeeb6214af.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_deterministic.hip -> fmha_ck_autogen_26d77b228420a3ead919474ec9c6fb2800f86890.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_dropout_wg16.hip -> fmha_ck_autogen_4fd34faa8b168e2ac7862641229e6146d3e28aee.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_5724d91c1fd6290a6cf8d52a3801ac6b921dc7d4.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask.hip -> fmha_ck_autogen_dd11806cd2d3ef1127f676b2d98bf8fff2a1e5ab.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask_deterministic.hip -> fmha_ck_autogen_aceb0641213e9a45ba48bcf72bb23845720d8b79.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_c0338fbc05f86270ded7df2bd3e2758a03961b62.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_2e8b4260626beeac76c26dbcee3cba1457b30e99.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_deterministic.hip -> fmha_ck_autogen_4e0a88ccef04e81b8c684b695f7cb4310e448915.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_dropout_wg16.hip -> fmha_ck_autogen_6f31b3345893eec8ed1ddf1d8de2512b46ff6187.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_dropout_wg16_deterministic.hip -> fmha_ck_autogen_83d920a76114c63156740ba5dd6f3846c4b21c28.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask.hip -> fmha_ck_autogen_86fa51b8c7a2f3fac5cf4cd2951ed2ede5c35450.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask_deterministic.hip -> fmha_ck_autogen_e7b2eb64b66d46359fab44333c2c484f4c9dd5de.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask_dropout_wg16.hip -> fmha_ck_autogen_81acf1d17650712b71a499bb66909bfcfcb6aecb.hip 
+fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_f1ecc90ad7b86791a9e6f73a582aeff30f393804.hip
+fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk.hip -> fmha_ck_autogen_f01468c62c878295443981662e037ec5213cf7a3.hip
+fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi.hip -> fmha_ck_autogen_e2deafd2f36cee29109fb824e0135407453adcfe.hip
+fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_deterministic.hip -> fmha_ck_autogen_b1766695dbb790bd614b83dc7569ad449404cc89.hip
+fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16.hip -> fmha_ck_autogen_784c35fee4d372123631312f1051c43e1fa12378.hip
+fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_38bb367362fe2c4849ded728ec5dd00969ce188f.hip
+fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask.hip -> fmha_ck_autogen_9afe4b6f3b901ff4af81bd4f1cd8ff19f09d0b07.hip
+fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_deterministic.hip -> fmha_ck_autogen_9ca3b1d36d777213eb381b47871bf15dd163c994.hip
+fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_d5edfe3e3dc3008b928c8e6dbd50784b905f189e.hip
+fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_10c24f1f9009e46afa3a59193784cc2575f79056.hip
+fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_deterministic.hip -> fmha_ck_autogen_7dac5d4cf103d658e129673549549f1276f134e0.hip
+fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16.hip -> fmha_ck_autogen_c8dbfaffc8a9b573f194f9c63f1175d9725f8950.hip
+fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16_deterministic.hip -> fmha_ck_autogen_566e26d4969bc6bbe9b092bedab11cddb3360c0f.hip
+fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask.hip -> fmha_ck_autogen_0ef309b923172f4c0fb38d9b9f5325b33b4877c2.hip
+fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_deterministic.hip -> fmha_ck_autogen_3bb3b682eab96e4e173affad75b9d8e73f1dd690.hip
+fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16.hip -> fmha_ck_autogen_f92e9a82c879051d6fe3c42108f8a574187704af.hip
+fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_4f44435491aa68acb3217b0e693232c67641a2db.hip
+fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv.hip -> fmha_ck_autogen_6082d55544b5280b49b071ea277fb1827193fa2a.hip
+fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi.hip -> fmha_ck_autogen_81bb8f13b6f20a72c9ce6d0b53f81eddbf05f1c6.hip
+fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_deterministic.hip -> fmha_ck_autogen_1e42736d4f677a59a172bd6f162616a437696351.hip
+fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_b9ed0a64deb55616646ea98b21a891c971cd98ad.hip
+fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_fb2fbb135d59028afcf867c2cf08edc323565528.hip
+fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask.hip -> fmha_ck_autogen_6360621af3f7e1e81a8be48fea8d2750fdecbbf4.hip
+fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_097b3e1dae9bfb2e89398706508f8e01966fd4ea.hip
+fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_4409f2a7deb027e864afdfc9975d3ab93c5dcc9a.hip
+fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_6d307974bdeeef95cca0d130ebb7aeb77fb1b6eb.hip
+fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_deterministic.hip -> fmha_ck_autogen_01ee0083f6df962c4a754cd3295b1a436c590a0e.hip
+fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_dropout_wg16.hip -> fmha_ck_autogen_c0a3c4ac0a50bb9b7ad764929dbee98c856b1210.hip
+fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_c27b3026f1dc3056dee3a3e64bf31c45683607c9.hip
+fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask.hip -> fmha_ck_autogen_5af96b404feac271dac8f4190180754480d3ba80.hip
+fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask_deterministic.hip -> fmha_ck_autogen_f69878f4ca8cfe6b8d8748766f66a1ef8eab20ad.hip
+fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask_dropout_wg16.hip -> fmha_ck_autogen_8689126a7eb09d81baaf8f99dbff8932fbeab3cb.hip
+fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_f6856ca950bcf173571766c3f04de4163be0402e.hip
+fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv.hip -> fmha_ck_autogen_d036096f49a89730f8af7e75457c88cb8ae64165.hip
+fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi.hip -> fmha_ck_autogen_03ff035717140f7385282419598cb4fb2881ce8e.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_deterministic.hip -> fmha_ck_autogen_de85901d66dc04b1143bb6404445baf65693b781.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_5c742b9ac6749f189d597ac97d46d35189472c50.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_bd9c47f3305e47db6ab6bc627fb3d80269633074.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask.hip -> fmha_ck_autogen_d82773721479613ad72e334510a248f1436b38d6.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_cfda56a4eb08b803332f25bda6209932d9624acc.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_328a311bafd1c153525393b252e4170f8aafb370.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_e5935fbda313d3518f142f43d46f56c600f69286.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_deterministic.hip -> fmha_ck_autogen_48e9e858abf6f77489f3fadc4ee81edacd26705a.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_dropout_wg16.hip -> fmha_ck_autogen_f71f96ce4dcc7f789a8ace73c230c203b05ff6dc.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_01d12033d59ce2799a2a024e5d9232325ccf1320.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask.hip -> fmha_ck_autogen_5854f09511778dd1779a839b0b194896070f69ad.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask_deterministic.hip -> fmha_ck_autogen_7237ce5f3cf13ace3efc0b0227ae5a8c1fdfce1d.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask_dropout_wg16.hip -> fmha_ck_autogen_3b4ecb47f9ebe8c2784976c3e9bbe4834b475cf1.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_f18c74becc24a93427d9c0838784e9b6caad6e81.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv.hip -> fmha_ck_autogen_c4c6c405cefe204824e8fad1b3dd34bba87e796a.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi.hip -> fmha_ck_autogen_41db3f29d1940e59dadc357c040ea37a6ff208d9.hip 
+fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_deterministic.hip -> fmha_ck_autogen_df4c9eb48da49a61957537270d94e56cb4e426be.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_6018ab272d7306689c7dc5a6d5326efea1471235.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_a421c2ed6b295c458071f1988b9d6f7b46e8992c.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask.hip -> fmha_ck_autogen_61a44ac409e914c12281f1d26e5b52d8bfd0df75.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_7e332a6aeecfb12dcf70c69157fd3137343fb9f6.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_2e43e401abbfb1b6737e4dc822f68421abbc648a.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_4afd02981f92fbef6277c1985cc479c12bae9239.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_deterministic.hip -> fmha_ck_autogen_8513d96a66a4d9fb8dfc84afba7e1d8c200248a6.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_dropout_wg16.hip -> fmha_ck_autogen_c4dec99707511cebd9188d216ee0a148d729b470.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_b75843bb13058ffe29251e053800c509c7590544.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask.hip -> fmha_ck_autogen_6eca9cd905ea8b0454cf9564643894682b08cb97.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask_deterministic.hip -> fmha_ck_autogen_c4b34d3cb673447773f6da23e9cf52b98e99f718.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask_dropout_wg16.hip -> fmha_ck_autogen_fbeec221cd63adaedceec39db41ea942f99f5133.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_2d7b637e0313cb423b22cd8844cc2997b3ff73e4.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv.hip -> fmha_ck_autogen_8fb224b40a7be7db0a9c5c08cc5ab05b526c14e8.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi.hip -> fmha_ck_autogen_e28fd64c2f2b27577109a984e6ab82f5f0fcb296.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_deterministic.hip -> fmha_ck_autogen_2eba937ff6d0302ab013db7349d4feb914107f1f.hip 
+fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_4e79dce18e49ffe024fe4cd0693ad3399f5edaee.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_8cdcdeb845e7bcdb89ef70ab2a97157d4db3cb52.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_30024440e780fdf9ec94deccc85216d8bbb5788a.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_c1f40c3421b9ad8cf43940530ec50bcf620058f2.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_7e89f79217037e361bb0909d06534e40f5026b4f.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_44564dddf8b492d80be54854abb8d1d831e42679.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_deterministic.hip -> fmha_ck_autogen_7831ce329f2a0812ebb1dd103ea4ba8cb7ba531d.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16.hip -> fmha_ck_autogen_b7a03ab0b7887cc7ed0cb40e56360a8d36c0bb8e.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_e986d5f8d5591f3e0f1cdfad19c38c420fd93023.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask.hip -> fmha_ck_autogen_076b3beb57b30afb30636f948e3989b346b38d20.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_deterministic.hip -> fmha_ck_autogen_2177d95cdf45f6fec95d1812f2ef183a75259e38.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16.hip -> fmha_ck_autogen_ff6862dbdbb20bc63a650e1f93e9ac169bb702b2.hip +fmha_bwd_d128_fp16_batch_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_649336d59a8b35919e593217b6fd4314a04ea359.hip +fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk.hip -> fmha_ck_autogen_14d11aad7b666f500f68b264a2fcca6dfc5f1a05.hip +fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi.hip -> fmha_ck_autogen_4d5f3cf0f78f73df79665c26b20b0805615e1b04.hip +fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_deterministic.hip -> fmha_ck_autogen_4bc48576f285325345fa1205e5e7e01787b74f71.hip +fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16.hip -> fmha_ck_autogen_e7c0a99e949baa5f3a7ee2d6e84427982f82f76d.hip 
+fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_4a2e6b05e7e4de2cb23d815f8b2c8adf22131c0c.hip
+fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask.hip -> fmha_ck_autogen_0842c4e3aabdf55405b3ce09ce1899245ddf11ad.hip
+fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_deterministic.hip -> fmha_ck_autogen_9ad1f99284aafc8d7908d062f179a056eb314925.hip
+fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_b80d0828ba6d24ea3c1a97bd9835ee937b4b32fb.hip
+fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_847feaf237911478173377a501ee19ee325b012b.hip
+fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_deterministic.hip -> fmha_ck_autogen_1a8da3e6ab050262b659c801ccf9a14787d7f176.hip
+fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16.hip -> fmha_ck_autogen_0225857454eaab2eb664aef7a0849ce12c32fdf9.hip
+fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16_deterministic.hip -> fmha_ck_autogen_80a72d70d80b66c19e85daa00497308381050048.hip
+fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask.hip -> fmha_ck_autogen_8b9043572cabb65435627a3faf23b18d039bbcd8.hip
+fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_deterministic.hip -> fmha_ck_autogen_378759ae25465c32960487375828e23c5f1ac869.hip
+fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16.hip -> fmha_ck_autogen_83ddca2c6ecbba4314c434e7471ffb8fa642f936.hip
+fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_40db688a9189e1c47c300d474df946a248a63303.hip
+fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv.hip -> fmha_ck_autogen_ad091c69d19b27f7ad50ef6311532ad8b642a9c6.hip
+fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi.hip -> fmha_ck_autogen_5e735b12d130ebf849ac5d6752e413ecf3e69fbf.hip
+fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_deterministic.hip -> fmha_ck_autogen_2c77bd7e89ed832cc31b2995566a49bec6e4cb52.hip
+fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_133c51948cf8584900807998da14d788039f53b9.hip
+fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_c29110dd501853e87ebc122dd1971b0bb1bcd92f.hip
+fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_9a9edbe35a8fac7796f00bde836bd547044770ea.hip
+fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_ccac6c0e61b65c9422c7f30fbd979031698370a9.hip
+fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_0d13a4c8d169877da6408584dc1f20a6f7c5e3aa.hip
+fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_744ec604c577a27e0aae5b39711a9e2eb82801b6.hip
+fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_deterministic.hip -> fmha_ck_autogen_43f2156a04b18bab55af60e9357f28d8a4604e8e.hip
+fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16.hip -> fmha_ck_autogen_dc9e54273c0ea2358fb573a7d918aa7b09fe07f9.hip
+fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_4f0aded9d1baec3125ce8e176248cb146ca580fa.hip
+fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask.hip -> fmha_ck_autogen_c80dce1a17d073259250ec0c87ade69e639ffa8e.hip
+fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_deterministic.hip -> fmha_ck_autogen_e307a1b0d5a8f94e0a0f4032f401d20b4b643523.hip
+fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16.hip -> fmha_ck_autogen_05538339c21c92c53d237865d72debaaf2ee5075.hip
+fmha_bwd_d128_fp16_group_b16x128x128x16x128x16x32x128x128_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_ec3deb1382003ac010d9bc1c59d1878d3ec7a727.hip
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp.hip -> fmha_ck_autogen_1f7faa0b33a9aada86f032174afd40d18efa7715.hip
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi.hip -> fmha_ck_autogen_4462b192a64efb60d5484798526278ac7a0fb9fa.hip
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_deterministic.hip -> fmha_ck_autogen_3a2643099365d0903c799585f41dc1a525ac9f9e.hip
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_dropout_wg16.hip -> fmha_ck_autogen_555ba79201a585bc091ccfc326fd24e851d1eecc.hip
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_153e897098539c3466da9d7a37234daf16476277.hip
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask.hip -> fmha_ck_autogen_38a5ff72f22e0ad040a281e66b1aca0bf3a2aadb.hip
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask_deterministic.hip -> fmha_ck_autogen_4b2e7f96b095ebfb66ecc7a75752fba2a63e4f37.hip
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_0fd4068ea93fcf4df463e3bf3a6898d23b65da7f.hip
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_2b823c3b99e7c8d1cdc39a5dbc7365a383bf9ccb.hip
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_deterministic.hip -> fmha_ck_autogen_3824e97d5ecba46e06d5ec1a9456c810d80227a3.hip
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_dropout_wg16.hip -> fmha_ck_autogen_a5d4eb673bafd81e3a0ee213da4603d88b8460ec.hip
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_dropout_wg16_deterministic.hip -> fmha_ck_autogen_40aa64439b80ff8dd12498b3e5f6b625da16e285.hip
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask.hip -> fmha_ck_autogen_f3bf7ef503bb026258b3ec3d82d3ef1443046964.hip
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask_deterministic.hip -> fmha_ck_autogen_556cd05288e1666f5c67fb87ad02ce660e4c589c.hip
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask_dropout_wg16.hip -> fmha_ck_autogen_fc030b61ae20c4b7d9b2d10930a17e01e9e93328.hip
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_f069b38b26c30bc770f74c856e47eb498f5818e7.hip
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps.hip -> fmha_ck_autogen_fc7b0916744b593435d8e1e7b6d874d760cd5e3b.hip
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi.hip -> fmha_ck_autogen_abf92a5314fd33491b5eb6ebd2418b7e0d5db774.hip
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_deterministic.hip -> fmha_ck_autogen_d41b6a64dd181f2efa65aaed03a3d229b3566c1d.hip
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_dropout_wg16.hip -> fmha_ck_autogen_80bfb0e6032892cc58cef4dd403f305a5b76851b.hip
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_8301bfc0394936a68fa0098580f06e77c88ebed9.hip
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask.hip -> fmha_ck_autogen_e9b53fa68641f45baabf40b7cfb8b35a9a1b9c7f.hip
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask_deterministic.hip -> fmha_ck_autogen_c9fb8343e623e46f01893a2b61345d1ca5928671.hip
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_320a6196b662a1d3dc7441a9536d825dc356b95d.hip
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_a3d7aa46528ee74e2bef1e87c1feceacfa55e173.hip
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_deterministic.hip -> fmha_ck_autogen_c59937be2b9a13d6520fdcc922e4e75c9fa085ab.hip
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_dropout_wg16.hip -> fmha_ck_autogen_e477abef05ff37ec27705eda51896e2aa3a04966.hip
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_dropout_wg16_deterministic.hip -> fmha_ck_autogen_3da8c31f6d5bcaacfa4a21aed4d1d3caecb48922.hip
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask.hip -> fmha_ck_autogen_6d40d762ed576832b3a752453e9881b5fe6d2650.hip
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask_deterministic.hip -> fmha_ck_autogen_3c1454ffc1418dac641f63671e947d9f550b1f0c.hip
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask_dropout_wg16.hip -> fmha_ck_autogen_96c129dd4c798343d6f78ab78056f0faf2f1c9d3.hip
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_242013527a0266ad479715ee3e6ae01c45de29d0.hip
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk.hip -> fmha_ck_autogen_2dfac5a83def98340c8786d55a30a98ad68b9eed.hip
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi.hip -> fmha_ck_autogen_ae51b30c7e1cd30e550187458350c8db7c59a9ef.hip
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_deterministic.hip -> fmha_ck_autogen_5e0abf4e2b6be3e2c555c2134705b9dcaee617ce.hip
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_dropout_wg16.hip -> fmha_ck_autogen_7309c38fc8a2d5ad6efd449107dc54a7509624fe.hip
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_245d90000b55ab8b6055b1934880fc6c4870b34b.hip
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask.hip -> fmha_ck_autogen_0b9585ba1c10acf67115c5899b3546608541820d.hip
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask_deterministic.hip -> fmha_ck_autogen_8e431313fe082958d31b68d2fd0d61df0fe56736.hip
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_1db03461737f1e359f389a8d297476f9b60faabd.hip
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_b4b037a2e262d11d3ed7d9feeb41b9e05427a739.hip
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_deterministic.hip -> fmha_ck_autogen_c919b8ed877d4244d01a17ecb948b459e361ff24.hip
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_dropout_wg16.hip -> fmha_ck_autogen_64cf03c0aa3f1b2a7b76b4e3418eb5063b982a29.hip
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_dropout_wg16_deterministic.hip -> fmha_ck_autogen_1386cd75411e61a8dbbaf2b916e62f4f5f99104f.hip
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask.hip -> fmha_ck_autogen_6e8cda718e10824956f0ee39bbb0891eafa45a7b.hip
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask_deterministic.hip -> fmha_ck_autogen_2ea394a09c8691a534ad2219bedf73724b6dd5ce.hip
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask_dropout_wg16.hip -> fmha_ck_autogen_748a3d76e8ab73af9a5d2302d33e3b1d1b866dd1.hip
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_e907e8d1089557dfcc95a05160be5092e9119a53.hip
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk.hip -> fmha_ck_autogen_c4c3425fe683d35dc3335db77d183ad1620b7a92.hip
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi.hip -> fmha_ck_autogen_d04dc4ed02eb42c3fe303342801ed3073a0dcb8e.hip
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_deterministic.hip -> fmha_ck_autogen_3ccf0a9d5a5451da5dbf6075ccea45e4a140550a.hip
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16.hip -> fmha_ck_autogen_d924ee32b178b6bffa7a71603d6e2818f66177a5.hip
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_aebd5fed34ebceb879ae3dffaf58c7c04ab5fe80.hip
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask.hip -> fmha_ck_autogen_5939e6610e41aff8d1ccdb66d9e84d3e48e8d379.hip
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_deterministic.hip -> fmha_ck_autogen_b4bd2d206ceb237ed2c51f58abb5cbf96e39d07b.hip
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_e56757fb17f5e94a6ba1fb14540a68c36d571159.hip
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_3076a6de0e2612279e0ed64612f7393856bcc9ac.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_deterministic.hip -> fmha_ck_autogen_ea6a6d4cc262ea838dbb83ee747112f95fa297bc.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16.hip -> fmha_ck_autogen_1a6bc2762b95d550485aa720edaf71138d94cd07.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16_deterministic.hip -> fmha_ck_autogen_614a9f10ebc51bde3f580ef527c17f89489c12c7.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask.hip -> fmha_ck_autogen_0271bd8b7c270e1593871b638288a4923342c446.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_deterministic.hip -> fmha_ck_autogen_4b74439f42140cdda9bb0f78d995d741212a35f4.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16.hip -> fmha_ck_autogen_d733f4c03e338ea7c6d8f759c1132499bdcea059.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_4432c5214c4d40c54ca2d02f0d4785c6d6902370.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv.hip -> fmha_ck_autogen_1f13a6d0f8c798c0c4ba4ad202d081899fe081ab.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi.hip -> fmha_ck_autogen_a1c71e7d33f0597fe090a3524e33e18b2e562680.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_deterministic.hip -> fmha_ck_autogen_e13b86fe4e153e0bfa8d1e75f3641fe32b0c5149.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_adae2d4f8b2dac799e03ea6f279e6ecdf66f5381.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_70586668a61ab88bc46b763df8f1c2ea52001ea0.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask.hip -> fmha_ck_autogen_82f0f3d71108dcc49234a258f0f3b21ea2123cc0.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_1de2f97d49f015b9af0b186801e939c6f357a0c4.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_9dc424f0e192155e3c4e786e5b87d5a1a3e6c4ad.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_bc744db85d4237ee9640f1658e0caab7648e3bb6.hip 
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_deterministic.hip -> fmha_ck_autogen_e8d8fe5f4f8641998b8b805a20b2ca92d019ee59.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_dropout_wg16.hip -> fmha_ck_autogen_549b6956eaf678f7eb901567d1a515eddbedae5f.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_ddcb1cfea1b0dbe50a02252cba99428fd977527e.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask.hip -> fmha_ck_autogen_86d73393d0d8b769f30222f7817563a955c36dfc.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask_deterministic.hip -> fmha_ck_autogen_249668a3212cd00edaae871758be30a5a1fea589.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask_dropout_wg16.hip -> fmha_ck_autogen_643b3798f11997d33ccb58d90ed6c10d5411b735.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_adda7ad787524e3e47dcc1b65c41b2faea38f55f.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv.hip -> fmha_ck_autogen_5d7ed4c885fb32a0b548186e56d64bab98071d30.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi.hip -> fmha_ck_autogen_77a814291d8f01870274149b9d82fb75921d6e20.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_deterministic.hip -> fmha_ck_autogen_f395bec57c3b2e6e169134dd8d20b287d7405134.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_0f588dcb2ef86677ebf84e406eb802e9921d1f1e.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_0caeedaa7d50f1741d618fb6c573529eebb075b1.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask.hip -> fmha_ck_autogen_1e33ce1fa113b221e5303b4093c2c4e748ce8298.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_ee974931e65d6b16b7c868d462b95dcae20b7513.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_85960fe542635079de5eca3c7785890cd4740005.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_089de13222caec1483207d4a54249f8da4f9c151.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_deterministic.hip -> fmha_ck_autogen_ffb5b7349a671b182d73c8016590f26fe06a4cba.hip 
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_dropout_wg16.hip -> fmha_ck_autogen_768c80fd3ea17813df1bf19a158186834fd00780.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_597a0276ec419f18f060a5186e6bb703ae434ac8.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask.hip -> fmha_ck_autogen_fc86c13e933cba40553ffba31d53aad27415ce4b.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask_deterministic.hip -> fmha_ck_autogen_31c3760f5978baf9780ce4587ae4c768af0e49d1.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask_dropout_wg16.hip -> fmha_ck_autogen_92b0770fe64e3c60b9e56170aa88bbf74802a813.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_c487a1a9933239270f44b1e08e1cf5323521c089.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv.hip -> fmha_ck_autogen_3a1dca5feb864e8981387c2d07e62acef1730aa8.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi.hip -> fmha_ck_autogen_96caa2056d99eb67ada498e287b4fae984397691.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_deterministic.hip -> fmha_ck_autogen_683e8a33fdb7053760c9c135002b0a94facbe015.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_7726be8909f631c04d4395fa4ffd03a736f447f1.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_c197d1f050f42d82e6851fa286db6f81ba197f40.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask.hip -> fmha_ck_autogen_d3a23ded424200d0c6f06b1dbd0a7b7b0e7b5d9b.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_6ff4605d82507fc4bd6e96095eaee5173ea41973.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_71e3980331dc4bcec6ab6f4c345c7b5f71356979.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_7e9c7feb747241c9c7de2adf3a19933a1c4c0995.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_deterministic.hip -> fmha_ck_autogen_1a236be9da05a07d11cd28034d90cdf89941a172.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_dropout_wg16.hip -> fmha_ck_autogen_ab0c3fe9529e24327686070731d0ac3ada76245e.hip 
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_66be70b088b20fc8de464167c35745461ddab640.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask.hip -> fmha_ck_autogen_fb4c15452f9155c5966990f09432e5eb7e28e785.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask_deterministic.hip -> fmha_ck_autogen_17b9b96edda151072215502cc2b606bf1f6f0b03.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask_dropout_wg16.hip -> fmha_ck_autogen_f36aaa63ed42a578b953ebd614318d44cf44e8a3.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_e578ec9e09d3b78dca6b5bf0be1538657f02f319.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv.hip -> fmha_ck_autogen_09513bff5c1da6aadf11d2e8272a422eabff21bc.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi.hip -> fmha_ck_autogen_f020134822739be6fa0bb3d98e9dec79f025324a.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_deterministic.hip -> fmha_ck_autogen_7a13d62a715fd717f0d4101f787349cb49cbe70f.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_d40569ae9dbd693c0ab3d6ba69704d31e451011b.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_2c808da5c2514806c2953bb77d5692e5d7c97aa3.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_bc79e255d25744725e2a9db9f90d5cc2b8a0e0c1.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_84dc4af43de08130a04bfa06df9799b6e9e96900.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_006c417a52a1bd7c55e45d111483d26f4480caeb.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_e02a198f23c409b715761b702d7b0e6e5992701f.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_deterministic.hip -> fmha_ck_autogen_468a5f057fd5cef2df5f919f5102f47e86901e3b.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16.hip -> fmha_ck_autogen_16047b5544acef40e39932672cac6f562e200948.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_741401abfbbbdf0dd1d62df8bc3e85371ead71d6.hip 
+fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask.hip -> fmha_ck_autogen_9009b7d39346537aa6c4a4e46b81139f603edb60.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_deterministic.hip -> fmha_ck_autogen_75c38912947881caa14b3fc7ab7bca317e296dc3.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16.hip -> fmha_ck_autogen_1e943fcc2e64c618fc1415b3f1a0db4d70aa8494.hip +fmha_bwd_d256_bf16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_6d470f5c6fb81032fcd7974180297d4bb2a8427d.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk.hip -> fmha_ck_autogen_aa1041530f794c7b8dc4a8321ea0fcdd338fff35.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi.hip -> fmha_ck_autogen_ec9f63a538940e5ace02ae5b5ddc01f730adac4d.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_deterministic.hip -> fmha_ck_autogen_459c8fb6028991321b09a990c2188d854d940268.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16.hip -> fmha_ck_autogen_a2a715b7e9c1a576f011dfe5769c5b392e984f82.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_64c3c1e3dac623f07c2dc1b934ccb868cafcb38c.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask.hip -> fmha_ck_autogen_ccd0b777df1328bf24e070ed4cdf8615bb2199fe.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_deterministic.hip -> fmha_ck_autogen_6dd707cf48a17d31abef94215c5720419faa0a39.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_687f4aaafd1a5b9ee85aadc6fab79ad0c27a2ea2.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_aebff7e6605b273bad844b8f70ef031625bff48e.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_deterministic.hip -> fmha_ck_autogen_cc127a63d56099e08125b16939dac82f0173122b.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16.hip -> fmha_ck_autogen_7838849e57ee9cd292e588f587a8079b57becfc8.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16_deterministic.hip -> fmha_ck_autogen_e638053e01268a4c5883620fc6a9901951e2e01a.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask.hip -> fmha_ck_autogen_405e7efa263223148318ae96bd1929b382e994e1.hip 
+fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_deterministic.hip -> fmha_ck_autogen_4c69d06e3f32e3b6d28d3e54ad764b472741c193.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16.hip -> fmha_ck_autogen_64a0ca185449a49fa485892fde6af745ba758167.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_af6ccfa11add1ae49888337e84d9c446d2f67da4.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv.hip -> fmha_ck_autogen_a487f617c4b84c6a0328fedac750d41dc3dafe27.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi.hip -> fmha_ck_autogen_6e6a4475ea795935f4cbf2dc0ac156a33d754587.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_deterministic.hip -> fmha_ck_autogen_d95835bc6f000d3a3379bbc38d90e83dcaf867ee.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_2c2e75e6f659a500dd3cf2cfd65118f111342119.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_28f2e2b108a53308a0cb6c123c8d318cbc2eadb4.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_a65c43b870705c780d734f9ef063f55cf8b3b52d.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_8fc08b4f3959a2375ac03f40c4ce12d70cdc2d80.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_a673f35edd69241c6b921d6712dfd064d78ecbad.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_ab877ae2a1aab04498bf2b26b3fe99d6488ef151.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_deterministic.hip -> fmha_ck_autogen_7601e6aea44b96e94fb019501be6b102c6e6a654.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16.hip -> fmha_ck_autogen_4ef35d82ceb4af2e07719c16109c6d72eaedce67.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_3c64c33870ebc329921cfa3867d58b1857421f65.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask.hip -> fmha_ck_autogen_526c89b7a04758b4badbf9695b316f877b8bb053.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_deterministic.hip -> fmha_ck_autogen_b3da22d3482738a8474ae15e8e5fca9020c4e195.hip 
+fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16.hip -> fmha_ck_autogen_f672bf80a78885428b2c02e522426470653a7351.hip +fmha_bwd_d256_bf16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_7c19fc90e5a9c422dbf529d2def286f47dea0f50.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp.hip -> fmha_ck_autogen_76704ca28a4877a1e84022e022614709adabb280.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi.hip -> fmha_ck_autogen_0029076f83a3dc695a167beda6fe19230a2b114b.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_deterministic.hip -> fmha_ck_autogen_da29a515d14dac02066bcd4701285b9916b43cf5.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_dropout_wg16.hip -> fmha_ck_autogen_33e7c1e5f41a451c7baff54f7238b220f1bdf8a1.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_3c38bb80e9880335faaea81985ed5d0e713ecb08.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask.hip -> fmha_ck_autogen_77d0223697ed41c4c2fd8830f8df6e5620db547f.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask_deterministic.hip -> fmha_ck_autogen_987a617fae00fa90a1ba60937b0312c81087c19e.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_1a6785392af35e27d6697b584cb6f17a766d3fee.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_f3fd08d56f8a9be1a8dd104cdb1ac58e283b5064.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_deterministic.hip -> fmha_ck_autogen_73d4901b8ef034590314048de7223a572d61ee0f.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_dropout_wg16.hip -> fmha_ck_autogen_0502e718337eab7d47aa65cea7d3c5f641484520.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_dropout_wg16_deterministic.hip -> fmha_ck_autogen_618031345ea71cc17e458eb97a559b7c94d3ae43.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask.hip -> fmha_ck_autogen_14c4ebd1792c781d219bd21b691b575f64635730.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask_deterministic.hip -> fmha_ck_autogen_56de9a7dfb1201b56528740e9d8a07b62710fcaf.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask_dropout_wg16.hip -> fmha_ck_autogen_cd0453a5c3828c1358360f31f5d3b7258e17fdb9.hip 
+fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_4cb1861e31df98bdfd731efc3d335055090d83af.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps.hip -> fmha_ck_autogen_be8ec1163a01b9cd9a802d8b44669e8770c20234.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi.hip -> fmha_ck_autogen_f0cad48d9bc80d58705ea60eb2dda4baad68cedb.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_deterministic.hip -> fmha_ck_autogen_ef7cc2aa1ffd38298b52764a93cd1271b4d92f8d.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_dropout_wg16.hip -> fmha_ck_autogen_3408103188e27b3bc55dce0c1716c0b4d32d6494.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_1bf767e7104cfc8322f26df35907fbf04b8948f3.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask.hip -> fmha_ck_autogen_9594816877815bc0294610ca24f986fdccdc7c6f.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask_deterministic.hip -> fmha_ck_autogen_d9061c204d8a85c974676f4438994a0be9d69a60.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_becc2a4d7ac045365300bf8bd45fc6d3e1e1c8b1.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_ddf5339054f47d9ed6cc7f9e66ab21ce3bccf3db.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_deterministic.hip -> fmha_ck_autogen_b01dc872c24db4db0c9179fc07e17f41060390de.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_dropout_wg16.hip -> fmha_ck_autogen_84e8ae99e184013739019c93d07caddce532382b.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_dropout_wg16_deterministic.hip -> fmha_ck_autogen_6a66604bb15f97a56847a7c968dbe32d247cbc13.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask.hip -> fmha_ck_autogen_90e5c56e92712d00092ba102a5eb5176a3e5d471.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask_deterministic.hip -> fmha_ck_autogen_8352031044ef2e4a22e27ad04ab5d2c02121faee.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask_dropout_wg16.hip -> fmha_ck_autogen_7dd260849b86c46b685955cab54ba07d49b47954.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_afda8f46b5ded4c2aa9d722fec17b75004b59f7d.hip 
+fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk.hip -> fmha_ck_autogen_98e484adeddf3394d8d7693b808d83b64c71ee69.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi.hip -> fmha_ck_autogen_cbd571f4fe576fdb17d5f75a558cb6747087c7f2.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_deterministic.hip -> fmha_ck_autogen_378bf438642e5d863e31145ada2a0688059aa5d9.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_dropout_wg16.hip -> fmha_ck_autogen_95530399ad7b43d8ce2c89da24c71056f2146b18.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_b00e062055933388e37525df5766f3c14cd3538a.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask.hip -> fmha_ck_autogen_236b3eef02b904304348b9d35f715b639d63218f.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask_deterministic.hip -> fmha_ck_autogen_069c663be0267c009be4814e9e4e7c13ec999411.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_a017be7b8bcf303b30a147f41346898acc5fab7d.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_d20d45aa85c0daa299da98c277cee826fe67bd27.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_deterministic.hip -> fmha_ck_autogen_b34c1ce348c3d9cdf6bbec9758de9d5fe94c43fc.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_dropout_wg16.hip -> fmha_ck_autogen_1c1b0f85e085dd0769c566fb16aafe5ab5952714.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_dropout_wg16_deterministic.hip -> fmha_ck_autogen_b513834918d5ea789e2db21abece7c2d3532a7e7.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask.hip -> fmha_ck_autogen_0513b2f3bd8ad51315aadb7f63737201898adca8.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask_deterministic.hip -> fmha_ck_autogen_4bd4d46397a3749646b232b306688e52b8c6e584.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask_dropout_wg16.hip -> fmha_ck_autogen_f12f1f1b679cabab04218037ef370d2c7e1fe332.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_d623b36cc3f56d1001b2d3abadd8a5628fefd014.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk.hip -> fmha_ck_autogen_3f5e01b4f2ca8ea10898c39d6570bd74e85f46ed.hip 
+fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi.hip -> fmha_ck_autogen_a5bdc110955c05c6c6ea236a6f60266a4a6dce5e.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_deterministic.hip -> fmha_ck_autogen_70c8e45f6ea7cf5dba9eeadd0b19481d9f5defb7.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16.hip -> fmha_ck_autogen_b5371415448fffffd58bf014dac9f4876153657b.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_ae4e80cb185759dd9b3eb3c67c239964b3694caa.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask.hip -> fmha_ck_autogen_096863cd93d1b105a617d0daa1d4f37d7fb6b893.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_deterministic.hip -> fmha_ck_autogen_ae8d0bdde763e617beafc0365ec4a3cd11df6c55.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_f7cf08242b3fb1c643d4149bec985b667b9d28fa.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_44c181996532676f2140fd026707135144e9d37b.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_deterministic.hip -> fmha_ck_autogen_8f6e463eedd3e65b9c79feed3cd92ad8cbc9f036.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16.hip -> fmha_ck_autogen_9638c9618dbf2af119e37596f7eb0fd3f8d72748.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16_deterministic.hip -> fmha_ck_autogen_7f80d44e82e601dc48d4c8b4e710ef7265894b6c.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask.hip -> fmha_ck_autogen_85908fe6dc9c629c82d6953081b10021e64583b1.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_deterministic.hip -> fmha_ck_autogen_fecd7501265b4c4dcf015485e63e2324304f70d3.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16.hip -> fmha_ck_autogen_3b508b92f7e123b21658f6e17d624ffa87831fee.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_01e2428c5447aa9a78f79f73f31cf685c586872d.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv.hip -> fmha_ck_autogen_e088f0f7363804cf5403adef70828ab32d09a02a.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi.hip -> fmha_ck_autogen_f4900c0a5c0d03dc17d7a907ab40652d9920e756.hip 
+fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_deterministic.hip -> fmha_ck_autogen_cb20538073888bdb3174a8e9c32d7449072aa753.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_6a3f42d5c9ccdd3807e488b00f02bc6ab5d8d99a.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_c9f1e7e478a2208c4d32e2d7e6abebdc16bcc5fe.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask.hip -> fmha_ck_autogen_8457ea5726149efb8778e6d90798b8e48288fc9a.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_37ad61bf8427a26775969f8a9166fd0bfb7446b4.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_72abb25dba0c48b380b2dabeb6ab7efaa706d180.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_1a5e18f6333ed2cce509f07cb8bd5868951d66a0.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_deterministic.hip -> fmha_ck_autogen_091cb49c1958fb4342d79f367ea93cf2b472f785.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_dropout_wg16.hip -> fmha_ck_autogen_a93324ccf11b273ed20fd960c61df897c8890b1d.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_906fa8bf5e992ddc25815486ae9c24d8bfba7227.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask.hip -> fmha_ck_autogen_6ef5803b33d97db72eb8a8528aeb3fc956a938cc.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask_deterministic.hip -> fmha_ck_autogen_a0874fc5ac87a1ec487c7722bf3b1bdaa924ee09.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask_dropout_wg16.hip -> fmha_ck_autogen_e7ae1294b6dea5c8b93c2b814fa7460c4047105b.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_238e4c1ca112afec494fbe47a85b553302c43395.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv.hip -> fmha_ck_autogen_ab09941bddfa9d61985b55f9b6bf0edec9bb89f6.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi.hip -> fmha_ck_autogen_48280c91d7cd8712fd533e246a6b0f758834abc9.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_deterministic.hip -> fmha_ck_autogen_6a95543aeed81adfb6d847f78212585a36122ae3.hip 
+fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_6767cce35ab784aa42ebcb75af7305bc38a8721a.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_31b807c48c472e9b1311a6037cd98e21d6706889.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask.hip -> fmha_ck_autogen_dc4d27535b9570b8f4b790470a83c1d0a9a2b6ce.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_ab56e886d53a1d88fada0f10f00b9f398dc54568.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_8adbdcd28cb2f078f89adf9aad2b3d4a0a477823.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_bcf8836c8cf932cc2748e313885003f0e11a887f.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_deterministic.hip -> fmha_ck_autogen_2af6c5be53732eb1939a2f93232af7dc011dec1a.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_dropout_wg16.hip -> fmha_ck_autogen_da9f6e1d59132fe96709490af25bd794f267851c.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_b31f56244076c501cb09b4b90975132cae4c4386.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask.hip -> fmha_ck_autogen_f9c58761c927b222112cb5cb6c9acb5d3c915785.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask_deterministic.hip -> fmha_ck_autogen_041a0718891596ddac1fb0088637029233ccbe60.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask_dropout_wg16.hip -> fmha_ck_autogen_9801b25e0f132d647934deb395b62a3f70cc7c88.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_6376eb68c550b50b9aea42a7a2cc3bda186b0e40.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv.hip -> fmha_ck_autogen_810dd4e870ceda3ba9b5f0084a4b025b2e609d57.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi.hip -> fmha_ck_autogen_a821661d8280c6e9d27f2c9ce1b3c855387b5a76.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_deterministic.hip -> fmha_ck_autogen_4be4a98f150f3f9ab6f03b5fd0968c5454565c9a.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_96dee49ec6755006d67f0c30c65f50558bba69b0.hip 
+fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_83d580a612af85533c87aecdd7b0345c71b75980.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask.hip -> fmha_ck_autogen_451fbbdc2dcf2ec81efce34673ee6c425cc16ca2.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_c4376ac8d82db1bc25fa273a80dfbf8b71ee5e2b.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_910cb8bd09d287a1566265eb1e8894fe68d3cc81.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_5b7a4ea3bb8905a22ae97a94c354b1cbe38093bb.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_deterministic.hip -> fmha_ck_autogen_da07d8b5666423da30a95e3b2cabd3839d200981.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_dropout_wg16.hip -> fmha_ck_autogen_5bead6be6e39ece0e5d44335083336f7f546d2f8.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_bcb6f0730fd09b4c6c60913425927dfdb8f83d82.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask.hip -> fmha_ck_autogen_ffd868d49abdb769ab82c21508d655daf54b8a99.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask_deterministic.hip -> fmha_ck_autogen_d9c3e27b522320dcca5ee84fa534b03aae2bfea9.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask_dropout_wg16.hip -> fmha_ck_autogen_c323a4d1f24d59bddd20ed2f2fb6446627b0ae8b.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_fa16fa84278b489af253b52839786f94aeeac36f.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv.hip -> fmha_ck_autogen_bec9e4c0317e8d351f60258ed6611fbf365c4024.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi.hip -> fmha_ck_autogen_13d5f2ec83b3331654e37ea0b44d88cd98abaa37.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_deterministic.hip -> fmha_ck_autogen_fd614df484b263deae3b3c20adb0ce7b62eaa651.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_05e60b3ab7477f9edc8576a8bf43e3a62b8d5ef8.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_fccabea88b8e290688c1b360875d228e6fdf1624.hip 
+fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_486f6c7c7655c34b7b9973ff357b0813f0a3fd7c.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_3cd7a9ca49c1149d46f6b05b0fefc41ecaeb6ea1.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_5e62968de58d9df7d687d671f37d63393f189321.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_807545400aa6e70ff49a5f38ed6a218a180bd87f.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_deterministic.hip -> fmha_ck_autogen_f5803aadd93e33567aa6b23100ce4fbb6c040dd6.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16.hip -> fmha_ck_autogen_4466b6c6b2ec3acb40ac1cda432efa1e4e62d9d9.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_bbfd025488e52b97c04995c4c5faff371b77e4d6.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask.hip -> fmha_ck_autogen_b298e213f927b518c693660110f08bdd94990ef0.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_deterministic.hip -> fmha_ck_autogen_d090b771a4f9750132f549c82a88b4ab00dce5c7.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16.hip -> fmha_ck_autogen_9068ba8df8b0e977e9769f6acf6cfee6b00b9922.hip +fmha_bwd_d256_fp16_batch_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_6d17b92fab5bee7717bf9aff6a6bef7cee3816e7.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk.hip -> fmha_ck_autogen_dd10bbf37503bbc92af82bc3487989b41b20ca85.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi.hip -> fmha_ck_autogen_f0209426a8e6bfeef7d8ae7b16db791888142298.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_deterministic.hip -> fmha_ck_autogen_e89bcea4393593313d18a4aa6dcb44cd75bc828d.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16.hip -> fmha_ck_autogen_e34b7e452a4db74189334697e3a240ad68085f0e.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_615430cb65d8d540836c7f12b3367abd3c8e63d2.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask.hip -> fmha_ck_autogen_afadc4f76e237514db0bc0203102297b79730bd0.hip 
+fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_deterministic.hip -> fmha_ck_autogen_249e6b93baae25dff97a0bc9145a8d328ed3f317.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_c806d7803d06ef8aac1d5caac9f36aafd47653d5.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_3163272d25bc2db2ffaa1fea87648b45ee68d408.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_deterministic.hip -> fmha_ck_autogen_b9baf70220079e6d4e87eb01a7259923d8a01e29.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16.hip -> fmha_ck_autogen_c5fcdea177734366d3bf283317a65cc3fffda611.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16_deterministic.hip -> fmha_ck_autogen_d25ce4b3e9cc392ceafebc7fe3bcbe05aaad4bbc.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask.hip -> fmha_ck_autogen_3bb129e6dee6848043dd0e8fa812ae80fec4d014.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_deterministic.hip -> fmha_ck_autogen_7d2f87c021e0b6a27b2d7e30351fd50f06414b5f.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16.hip -> fmha_ck_autogen_f4c803838f5644ccc6f04f7c8a6233fed0b6639e.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_2c82e3c4e445e1e02f14435e4ca01a90850139a4.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv.hip -> fmha_ck_autogen_a21f3637624762547af1292e1b85e640b1d329dc.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi.hip -> fmha_ck_autogen_c9ba0a3369d4e4eaea1c902a90e6501f232dd57c.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_deterministic.hip -> fmha_ck_autogen_1914250fce818584291c69a5f058a58cfbd83df9.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_01d3b034a2d8d0b83c0aefa4faac6c3f28ce737f.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_5d707d065ae152450f9def619ddc3dddb9089e88.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_1132b11429034d96d82c82dbfdb69e460ad8a564.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_4a5dbf601de5754c03a03a1a42395dc0766fb8ac.hip 
+fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_5a29b93cee012c79d4364502f1d90f947c73641d.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_01e8f0df0c54ce619e5b66441b3c96a5e18b05d6.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_deterministic.hip -> fmha_ck_autogen_1d498e418ebbf33bed58b4074d1edf3d9bdd07c5.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16.hip -> fmha_ck_autogen_4d7dc0f356b630179916f8fc2041b7f1402b46df.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_292454f2d82184ab0491ea0675750c6ec55d659c.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask.hip -> fmha_ck_autogen_c538dc4f65d02776875627cbd20a9c794d70b043.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_deterministic.hip -> fmha_ck_autogen_2d1f2d1e57095f756ddd11e8e9d4f6f253e3ffa3.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16.hip -> fmha_ck_autogen_16f94f5c65c37624f5458c165daf83517d9e3c81.hip +fmha_bwd_d256_fp16_group_b16x64x256x16x256x16x32x256x256_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_2703018e71d57d3266fc35e2e18a78faa3dd52ce.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp.hip -> fmha_ck_autogen_ce5064e27ba427cb951f7e1b01328b0beb6b2b7c.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi.hip -> fmha_ck_autogen_aec87e65afa93e84d7a947c52f291c1c7360033c.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_deterministic.hip -> fmha_ck_autogen_30f0200092b0e18d57a9f5e512d565f1c0229436.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_dropout_wg16.hip -> fmha_ck_autogen_61896aa9e4e4d7e494c1755b1e77a08e0e264f8d.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_487724686efd35731e5335efa949486c93ae26e3.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask.hip -> fmha_ck_autogen_3e61b019e1398a6a3c36143fb84b5ff22c9f4508.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask_deterministic.hip -> fmha_ck_autogen_94a94d145e575747c8956ac703810582c819e2e8.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_4d3b1ae63e127b6e6afe39e354d4995afc5faeaf.hip 
+fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_438e3565f4c720e6c9691b0d33c1392936e2e7ae.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_deterministic.hip -> fmha_ck_autogen_d3fce1e11aee2273620e75efe4aa0390fcde9ba5.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_dropout_wg16.hip -> fmha_ck_autogen_1f0cad6ad5b172e51c569e84cd54a19b4eb0ed05.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_dropout_wg16_deterministic.hip -> fmha_ck_autogen_d54b3731883a5f8393d60d27487f8d017aedd3f9.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask.hip -> fmha_ck_autogen_0efdaa9266a5a464009297dc59db92504f8bf1a3.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask_deterministic.hip -> fmha_ck_autogen_99f8352674bd6bbe98944a1c0a769a4fc028a623.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask_dropout_wg16.hip -> fmha_ck_autogen_a5f2f0cef657ae5e333d65ae4ab20529a43cd7de.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_9eef1b54d5d3841f3fa6b84cca6c7ad33efa2d9f.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps.hip -> fmha_ck_autogen_92ba64cdf615c1be2865f027a293cb530fc07dc6.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi.hip -> fmha_ck_autogen_931cf8d05cfa45319f4e5bb49334d35a530bffcf.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_deterministic.hip -> fmha_ck_autogen_34807a8e90bf1cd839f32fd718afa6469c35a4fa.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_dropout_wg16.hip -> fmha_ck_autogen_1a98bcbe900f8c141136d18c114b02fffbe8bca1.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_63f121a3c8928c10a2d86b487cd13fa995da670d.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask.hip -> fmha_ck_autogen_8f607ee20c0d92b6dbd0338f139517fdcce98d0c.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask_deterministic.hip -> fmha_ck_autogen_3a6b9566559ed2b1c85f2bea1c55e72c41dc47bd.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_45f4363f50af1e7ccd24751d5f5b181bf32c604f.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_22a07ecf1a59f72ec6bef3e970d7f33cf54c5f44.hip 
+fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_deterministic.hip -> fmha_ck_autogen_3400f0af03743dce328486f8fc805dd30bd6da31.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_dropout_wg16.hip -> fmha_ck_autogen_9b841b7cf5da31f0c30ec42c91cc8d5bd3fedd03.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_dropout_wg16_deterministic.hip -> fmha_ck_autogen_0e3f4cd28a4c06cc109f6a0798a77844bcc750b7.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask.hip -> fmha_ck_autogen_a103cd47156a98ad2cf2c325ea00df3f1d67fb72.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask_deterministic.hip -> fmha_ck_autogen_bd37f4f7914805a97d5073f1ebf8a8b8c2648d31.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask_dropout_wg16.hip -> fmha_ck_autogen_030a759dcc92028b4c6f317fc230b98cb929e806.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_8d79fe8a600c3b4e0ec9aa510f8036ba2b608985.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk.hip -> fmha_ck_autogen_10ceed95b0a0a01f844678717c88e0426fb503fd.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi.hip -> fmha_ck_autogen_90b17d8cba28cceddb3ef907df878aeef0762d15.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_deterministic.hip -> fmha_ck_autogen_b5ac596c636df55e81293228cbc53dcbb3024e5a.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_dropout_wg16.hip -> fmha_ck_autogen_e68a9e05debd456a9975953f7b0d510e7a0f6978.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_50f915b4d9bd18a3c25a85917392ea4a5e88b349.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask.hip -> fmha_ck_autogen_dd67d442001d2b167e70e8730abde4d4461b8569.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask_deterministic.hip -> fmha_ck_autogen_4160f6b6d0869740a5a411abd80108f729f810eb.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_40357c5e9739eae136a7abf92bc38d3ac94753f8.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_b4ec377c44ac18527ca6a01bc3b146706a6e1e09.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_deterministic.hip -> fmha_ck_autogen_02d88a03cd3966dd0cff550065f58c3ffecfff6c.hip 
+fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_dropout_wg16.hip -> fmha_ck_autogen_093834d4d3fe76e1745e4482c6b51b550c6f3dfc.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_dropout_wg16_deterministic.hip -> fmha_ck_autogen_3cf45927b6d931e31e2209685d787efa28eed8ba.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask.hip -> fmha_ck_autogen_634d530731c7ade2c7beecfd1bbbca8583032217.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask_deterministic.hip -> fmha_ck_autogen_311731442b756308c0a869f21b7b8b103aa613e8.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask_dropout_wg16.hip -> fmha_ck_autogen_2ae344010d49f7f9a6caab2cb84be7f87d2d96bf.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_ae239476d61f48379754b97f29d7a285cc3192de.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk.hip -> fmha_ck_autogen_6e7e1d245baabe2f6293e3d85318f9936b333500.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi.hip -> fmha_ck_autogen_f6566441ac3074578cfe45758ba0583c0da0a5ab.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_deterministic.hip -> fmha_ck_autogen_de26a187c4db06115072a5132e1166b5b03368b0.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16.hip -> fmha_ck_autogen_71dcbe9f481c92215f3b636bc0e86ce8f65e6472.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_9a20fa19d8d30654602e363806f559113218d66d.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask.hip -> fmha_ck_autogen_1e22f2d99804198c61251b4629a3f18ed3dcd42e.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_deterministic.hip -> fmha_ck_autogen_38abcbeaa4d33d3150f2b0238bb62ebbfe960980.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_d0863830fc5d43dc6d6400280e892bb7de2892d4.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_0cee6b9427c164d78994150305a47f73954a67c0.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_deterministic.hip -> fmha_ck_autogen_04caeecbc01667ec6f5599358a0a20423aa9a00b.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16.hip -> fmha_ck_autogen_33099fcfc218ffdf69edb4f2f0e46121bea9fafc.hip 
+fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16_deterministic.hip -> fmha_ck_autogen_3dba3cd44f78c950fe7ceaa5f0629dfc607b30f1.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask.hip -> fmha_ck_autogen_c2f04447e6a94c94a2315454e71d7d607a9fd0f8.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_deterministic.hip -> fmha_ck_autogen_8e2d5f979fc4fbd0991581a020a414f9c8656ae2.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16.hip -> fmha_ck_autogen_1241814f76107d74ed069ecec99a248676487eee.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_bd28203f47b6a48e9b66302cf8312f3796ca500c.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv.hip -> fmha_ck_autogen_3d289100991d4c8c362f64c8f6c4ba395c2f3495.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi.hip -> fmha_ck_autogen_7c23dde1a386436e9864c8fa5f1706c0d2fbfd0d.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_deterministic.hip -> fmha_ck_autogen_bd8bf7c572c1984ca3061062cf3c31d993f6762d.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_164a947a6c2ba83a5b1cb7074aee0bdac6c9c64e.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_9b062dd633645772e4f2caffd111af73184f7657.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask.hip -> fmha_ck_autogen_abf6c6412f9853855b74a96e862935ddef66f763.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_aebb2441e6cc1ccba4a391566e547402bcf7ced2.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_66968bbf7e210911fcb95ba90c79837230ab1ce3.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_be1e1533fc37b41838bd37edc2b6d2f2e76ae1c6.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_deterministic.hip -> fmha_ck_autogen_3a2280997eb6f1d091094fc54cecf42b7c9c3a2d.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_dropout_wg16.hip -> fmha_ck_autogen_4b4c03c916393d6be7c5181369ebcef949eaa763.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_4ff20bafbf156fe8fb80bdd84a5d2f3a4a944c1a.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask.hip -> fmha_ck_autogen_be4dd90ccb2f258029d0156cf23f940b694cf08d.hip
+fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask_deterministic.hip -> fmha_ck_autogen_e334e691714f0b99773c2ac515ed82de0f387065.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask_dropout_wg16.hip -> fmha_ck_autogen_62eb2f81e73d65fddce7ff43c397da6529317607.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_285e61dad8f63fb973cb2eb899c959e400622652.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv.hip -> fmha_ck_autogen_a2ef5d30a2318ae06430d17f84878800c4ca7364.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi.hip -> fmha_ck_autogen_54548ad36fb92d0963893146c8db20f53cbf0c8f.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_deterministic.hip -> fmha_ck_autogen_3967a8807c9451b09227c0f685c18aafeb062fd2.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_94f6f9dee9f0c3825d91f4d320a5280070e60ee7.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_59d366421e0b51c90fa53c366d47ed8d51b3a329.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask.hip -> fmha_ck_autogen_dd35634440edb25cb095800b882c70aaceca1dbb.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_0628931bf5cc1daa6e106cf60bb21fa1aac6b1df.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_ae4e7253ad4873576052ec0a9400597bb7975753.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_f6f102a388ffb05c690a20a29cfe0b35a35eed61.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_deterministic.hip -> fmha_ck_autogen_235bf652702c2976551778b9159e09188575c63c.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_dropout_wg16.hip -> fmha_ck_autogen_2a45129fc4995abcb8f880692f11c6186fc01641.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_ff453e3bdc9752cb7b81f7cc3056325a8b9a8ad4.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask.hip -> fmha_ck_autogen_c08095341ca7e3a1debeb780c1878e351692bee2.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask_deterministic.hip -> fmha_ck_autogen_5de27c4081377f59363c2bf2ea8624217566d2d3.hip
+fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask_dropout_wg16.hip -> fmha_ck_autogen_8c4688cbd23727dd0ea9a36fb977b31aeae98d65.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_4666db0ff7b035e54f2c0e59acedc2131b722a55.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv.hip -> fmha_ck_autogen_783ec08544591a22f59dc12f169b7327b4185a1a.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi.hip -> fmha_ck_autogen_131691f01cc7f29affb88152dd48c7a484315dcd.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_deterministic.hip -> fmha_ck_autogen_dcf815ef540060cc7ed43e1c57a28e1d080c5621.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_d7adde8780b39f1364c572a19c3bfb19417678e3.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_cf5c6c0bfaf98f6e655fc443246b81fcc730fe97.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask.hip -> fmha_ck_autogen_b18a615e66d7cd739ce35412811359a03cb23a8e.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_0fbddf533661642d84bf5a16149692d5a892182a.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_fde12cd366d6850ce26afce98e5076b695b4875b.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_df0b2bcba57e77d975ec5304fc50cbd09cddf4bb.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_deterministic.hip -> fmha_ck_autogen_cbe5a98163e878c7697e554758ebd0597c2c1760.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_dropout_wg16.hip -> fmha_ck_autogen_3cb0cee09d633b6f70febbba63a1e090522cfb4a.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_eeb0e96b759e18cf703cfab0cda1385726f6e0a1.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask.hip -> fmha_ck_autogen_4601680af41c8738089ff377147e0547dcad114d.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask_deterministic.hip -> fmha_ck_autogen_6f3d098f8bb63133924aab70d26a6ed64018c13b.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask_dropout_wg16.hip -> fmha_ck_autogen_7d08373ace7087bdaca4ce8b0bc329f553f88d77.hip 
+fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_4ec2075f394acfb14fae7b1ef4304fd9b654ba0d.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv.hip -> fmha_ck_autogen_8a1fd28acfe85b3adac859c4bbffa4d28fe634fe.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi.hip -> fmha_ck_autogen_7bb7b63e8a4c1df4eac4d978e166867195bd6e53.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_deterministic.hip -> fmha_ck_autogen_80fb694fce7b4c3c459fca43c89c6002fbfdaef5.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_86513d6e065a44bcb0c789eed1e7e5456e800ab6.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_31222e158484773d2257f4a31e3dfbdb68336a8e.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_b20c6252863a73341b0010191fad4c834860f884.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_70cf755f1485c065222be4daab84283a9c3d0eb7.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_0c8a0bb89a6f05289c0405df5126fa0cc16252e7.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_88ac7f6cbdfca2e397bcb86af4216e87166601c7.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_deterministic.hip -> fmha_ck_autogen_db8f0bd93b352d28c5b6d78f4332026993f0bea4.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16.hip -> fmha_ck_autogen_96c5e79f54b71677124f555b0ae4bfd27248d099.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_0b532fcf26f90c82a792cde7943634f667c1d033.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask.hip -> fmha_ck_autogen_b6b17ae67adee9e56a022cd2a5514fb9c4e99920.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_deterministic.hip -> fmha_ck_autogen_fa62a97675719c2e8e9bb97361b92ff1c7b9d2ef.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16.hip -> fmha_ck_autogen_173c44dd85077e6b12dd06fdcf6b11ba349e1866.hip +fmha_bwd_d32_bf16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_f861d8693f82d22e2c5b1abbcbae5f30f4433e5e.hip 
+fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk.hip -> fmha_ck_autogen_970073c70133ff2ee4737f803a0ac43801c47242.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi.hip -> fmha_ck_autogen_5aba1183efe205af38e79a1b2dccea5fa515d02e.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_deterministic.hip -> fmha_ck_autogen_322a86568f89a5a5a165cfffbae9ca6949f2477e.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16.hip -> fmha_ck_autogen_321500dd4c41e4d68834814a48a639f5ca36a2fb.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_4f4a5d56721bb1a1332a65882132a8c5763932ec.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask.hip -> fmha_ck_autogen_44d82b58fdc3e5b7a7c20490ce7f5acce4e6ec79.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_deterministic.hip -> fmha_ck_autogen_678a4a8210a972bb2ed89d6ac754fb79438ab2da.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_085722b43cde5f37242edb071f639da7c4a0bd48.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_417b1cb14b67dc82f614831550f7deb0895bd7e4.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_deterministic.hip -> fmha_ck_autogen_7ec04763d635c5bc3e810737b5d948c59f117d5a.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16.hip -> fmha_ck_autogen_7524904ac5a2040c7ea72aef5942212f291a21bf.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16_deterministic.hip -> fmha_ck_autogen_6979ef43adffdb62100270a62706fb811963925a.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask.hip -> fmha_ck_autogen_5be9ed84ad9be1627db7a66af9370679816c0897.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_deterministic.hip -> fmha_ck_autogen_ee239db5a67c23a383590a651f0d8a0be43a13c7.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16.hip -> fmha_ck_autogen_69214eb450c3b249017480efb8d092b0edad6dc3.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_0c32a2d9701e23dd930119c4ee8089042b5b0ac5.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv.hip -> fmha_ck_autogen_97246460c21bc66c0f13936d27477a9fca1c44d1.hip 
+fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi.hip -> fmha_ck_autogen_078b96ad691a85eebd18586db0b62b8911016d9c.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_deterministic.hip -> fmha_ck_autogen_7ee953cb24e28bcdc8f05783894b23cbf83bdf35.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_921f789d619db6f225e8e9d646e93bbc9dc1a669.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_28e4d2c757e4b8c366a2c320360e21ff0ef671a8.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_da6afccdee4107507a64323e17bf12c46da2b92a.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_38e12dad9e3bafe177ed3c27c833825813e18fc3.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_151a4425b411596c46c7032f6b83d3152a0e0cd4.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_d1d3eacc320104100bce46235fe656e5a8223c66.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_deterministic.hip -> fmha_ck_autogen_a71305f191f06cd53b7563971c706e8b71b19e2f.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16.hip -> fmha_ck_autogen_482e34930d11ff493007b1613993e01acc1af78d.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_342d29c85070f488a14b1915f948e5fd69019c99.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask.hip -> fmha_ck_autogen_b0f555b74ed36f1bef8f47880b3edc6760f27788.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_deterministic.hip -> fmha_ck_autogen_c42ab428503e8f8bfa78c8cb8d9afad9f5185118.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16.hip -> fmha_ck_autogen_cae6c7efbfc831e2bcfc8c1efa1a486c02627cbf.hip +fmha_bwd_d32_bf16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_bba10ecb79ede07324e1198a71a95ff26e9eb235.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp.hip -> fmha_ck_autogen_48ae3af78583258c4b13c11a442022e0e058bb85.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi.hip -> fmha_ck_autogen_82048cf91270631f98ac37dc488a1fb2e00ce004.hip 
+fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_deterministic.hip -> fmha_ck_autogen_fb4c5f8fecfbbe16e6648becb3b5ca89fa3d8a94.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_dropout_wg16.hip -> fmha_ck_autogen_6abeb7b50ae6a1fc62535b9a1dabbde6f177a9d0.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_8a824621a50cdc3cbadc4b1f9ef18e1325385082.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask.hip -> fmha_ck_autogen_f69548d6cced86c21c09c6475237a0cb926df0ed.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask_deterministic.hip -> fmha_ck_autogen_325fbcb9e503e68fafea08abf86a4951f440850f.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_964f916d3484295b5918e2e4c22c5529588a5662.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_df645b3888dc8d1df50c47c0d75822eebd3eb019.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_deterministic.hip -> fmha_ck_autogen_75a310a6eb86e3e8baac7a930c3ffbef372942b3.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_dropout_wg16.hip -> fmha_ck_autogen_458d708d13577f2b92e6d5adfe952a87e0cf7be5.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_dropout_wg16_deterministic.hip -> fmha_ck_autogen_15fe3e8f4add16a088fe44458353fa7c0c4f9658.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask.hip -> fmha_ck_autogen_0d0e0147a92061d32608a34e7b47bd534eb787fa.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask_deterministic.hip -> fmha_ck_autogen_4e15e4f16de26068cba30ef12fc29332d45e460e.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask_dropout_wg16.hip -> fmha_ck_autogen_e2bf6805a489739abb77c13173d57723e9304afa.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_15cf7068183421b141ed5d6e7fe902d06b6492a1.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps.hip -> fmha_ck_autogen_703246f1f53a988cf252eff88bdf814bd382d3ac.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi.hip -> fmha_ck_autogen_381b29d9888365bff0f109d897b508eebfd8a61f.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_deterministic.hip -> fmha_ck_autogen_f2da112b1e07c44fc8a7f19368da203f6935049c.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_dropout_wg16.hip -> 
fmha_ck_autogen_1886d4bf54b3a4a9e093360998b2059b3c03d072.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_b60a4e87a7aabfe3c1ce02b408522f3ec862e3d7.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask.hip -> fmha_ck_autogen_a62a2ab489839ea1a1bfd1b24e54a3c232ed934f.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask_deterministic.hip -> fmha_ck_autogen_36a0a960541bd8a2dc6741579de685b7c0a5f6d7.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_1f6bc5faf18be193212217788d476ce6fd384bfb.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_13f747525ad31e76c88774fb2208e470da9c2310.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_deterministic.hip -> fmha_ck_autogen_71b6100efe30d836dab557ea4ac54c4b9d35c6aa.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_dropout_wg16.hip -> fmha_ck_autogen_62ab710e4acc711430745e05e036dd6a4d6bcdca.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_dropout_wg16_deterministic.hip -> fmha_ck_autogen_7597ce4d2e5264bdeda47487d5bdb55a014c6616.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask.hip -> fmha_ck_autogen_ec7ec8d547ee9713aa3b5b667f22cdcaa8f62b2d.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask_deterministic.hip -> fmha_ck_autogen_4fe530cbf6363a8f08a94728e45e88ecde299e7b.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask_dropout_wg16.hip -> fmha_ck_autogen_661ffaf653085dd7f122d603bb3ba4b001e5f3c0.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_345ea796c8d97bfe3b7c9663bf15e2e5e7696235.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk.hip -> fmha_ck_autogen_802b21f9588d72c3c3e3b9a3b269f19c484d5aa4.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi.hip -> fmha_ck_autogen_9ae866c7db36286876818bfb718ac35204fa3843.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_deterministic.hip -> fmha_ck_autogen_faf56e45b2240515e97fc1bfd552eb03b6de5094.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_dropout_wg16.hip -> fmha_ck_autogen_fffbfcac254e33926131a71905e93f9cc0aef89e.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_919ae177b7a793fa352c4f6bb8e4175f3064d814.hip 
+fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask.hip -> fmha_ck_autogen_ac9382cf8bb56ffd962c99329bf67da992f8810d.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask_deterministic.hip -> fmha_ck_autogen_cb1deea4f4fab0db31d46a91228601f0c272d6e6.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_144f19363ef26efd36f0436cfa9f84f181a8824c.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_0392491c5a6dfc742c2be483419a40f6a7a7ea56.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_deterministic.hip -> fmha_ck_autogen_cabb7b12cdd9b8b522af577e13232b2459dbd38d.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_dropout_wg16.hip -> fmha_ck_autogen_caede7a18f3e3d5e24f6c70392413a2cda16ac15.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_dropout_wg16_deterministic.hip -> fmha_ck_autogen_b9d00ab8373747a5c6b9d2f8dd50ceb14db4163c.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask.hip -> fmha_ck_autogen_ce909cb5f96a4884caa0d2eb8c5e6bc7fa352797.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask_deterministic.hip -> fmha_ck_autogen_1037f1bc50c4a65dac09ba56b701256b701c4322.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask_dropout_wg16.hip -> fmha_ck_autogen_aafe891dad43815e635f81225705ff944f990d75.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_7afd1a756247b15b078d15a39e350a07c22982da.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk.hip -> fmha_ck_autogen_3e839660557dee9d5bcda9b56940ce23236c5f6d.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi.hip -> fmha_ck_autogen_fd26e43ca652e6f58ff48c356165aa4349833b55.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_deterministic.hip -> fmha_ck_autogen_b3486244e0b7d6dbcaa1951e8b8883ce441c3f99.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16.hip -> fmha_ck_autogen_90da0d469cca5c8481504148468460c85a15c559.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_714c5369aa848021e020d874289e3ae4e0f74d77.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask.hip -> fmha_ck_autogen_cc54b107e1b557ea36b5cbaf7fe3dfce05415c86.hip 
+fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_deterministic.hip -> fmha_ck_autogen_00042c36bc588e60a7c8a9ba297a8a25d8ac0660.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_65794d9c185b21f59274ac5d4db10a7abc0be968.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_faf686067fa433cea5e95dd523846dc881eff635.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_deterministic.hip -> fmha_ck_autogen_39d3071347a0c98f3221104036f477aa13bffa4d.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16.hip -> fmha_ck_autogen_e76879f8ff4796f48ad87ff8003f4f6e6adca9a0.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16_deterministic.hip -> fmha_ck_autogen_4377ac04be3a6cbdbfbe57612a469412812fb5b5.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask.hip -> fmha_ck_autogen_c1b76bc7a17f573c0d52c07ae9ff4302662ae61f.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_deterministic.hip -> fmha_ck_autogen_d713fe25dc90b3511fc259cebf463376dcb55d84.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16.hip -> fmha_ck_autogen_89a3327da9a3411ff1cddc67eb647083cd947a92.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_6a7eb3d86aa385f9ecffbc5ba10489e56856f918.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv.hip -> fmha_ck_autogen_4d65e58c9f147498ed04dd51fe1393770603a6d3.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi.hip -> fmha_ck_autogen_a5c0109313de1f6245d2a80f8539485b849e9d55.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_deterministic.hip -> fmha_ck_autogen_e73a776ae4ba68c23acab1a5a6381684051738ab.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_a225c4f1f3c7b271957768bb9235131c67afb48a.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_c9530e20038eb40c49bc8b045be0cf4e7e6b4eac.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask.hip -> fmha_ck_autogen_f51f1a11f778d99a00aa5959a3e58a41fcbfb1e3.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_de36bc309877917a18fd21acb30563c7e2f233c1.hip 
+fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_4b45948f2795293e72530b02669c4f549608ea7f.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_05f794c7023cbb7e35f1fd1ae45bd2377bfbc520.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_deterministic.hip -> fmha_ck_autogen_82f1d7e1a93bf2fa80c409e6827ea88af56c44f0.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_dropout_wg16.hip -> fmha_ck_autogen_4baf664bfdf070362bcc91af77d1bc406f744351.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_80efc341089a50ed5669b3c86f6ddd9b124d1442.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask.hip -> fmha_ck_autogen_e465193d97d43237c22c04478ca5833011d8dc8b.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask_deterministic.hip -> fmha_ck_autogen_915b75db795dbef037b14b003ee073665fe35d3e.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask_dropout_wg16.hip -> fmha_ck_autogen_fb5bb49928ce5515d7b297d5eadd4ec70a22d60b.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_349241529745bf138552f49d9a93db418663ad65.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv.hip -> fmha_ck_autogen_c4de1bc135191f3c2aff740f4c6bb7e98da42f84.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi.hip -> fmha_ck_autogen_4ce03571f1d2779bdeaf0a6a2d617e236d191c11.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_deterministic.hip -> fmha_ck_autogen_ea077e68dbc1bed2dd20a5f4dd35e0cad6330ee4.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_c56aa150611b0d4800470c1493dc907082a5c23f.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_7e9519dd0d0f940fd5efd61bd32df7528ba7e3fc.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask.hip -> fmha_ck_autogen_47548aa042c69bb9c59a8bf706b44028aaa41830.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_3dff884e176ec7cff86d17c6afe1ddaa4dd6007d.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_2081430c92864c29bb9f409e7c27caee1de00749.hip 
+fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_3d1cea88a2277b87d405025ba256272a1720f88d.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_deterministic.hip -> fmha_ck_autogen_a55c7dd576e5b1061c059e5e99aeedf4389e2d25.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_dropout_wg16.hip -> fmha_ck_autogen_8c074afcf33e3f3534ac3577484237fcfd2ca48e.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_e618fb4e529104fc90069c8779ce5463460bd516.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask.hip -> fmha_ck_autogen_44462715ed5f192532760d6f4c66ff9d4e20e254.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask_deterministic.hip -> fmha_ck_autogen_e1d85ad2c9d197f501267fe0804e6985802fbd18.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask_dropout_wg16.hip -> fmha_ck_autogen_78663faeb0425f45e8a0da0f7b1a5ddbee5e07e7.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_15e8e1ab8c63db96843054bb7a98d708ae6a9c44.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv.hip -> fmha_ck_autogen_629e0b97b3fece7c12504f4c8f1860d611b57269.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi.hip -> fmha_ck_autogen_29c9e5384809b21f39e78bb2e43af345a9a21d19.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_deterministic.hip -> fmha_ck_autogen_14f77aeeafe4b28f314fde5ebccfd2a554872781.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_09d76cca48b71dbcc9bd96734787209fee4c9a74.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_a4980becb0d3149fee575bad1fc3b463d08aabf5.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask.hip -> fmha_ck_autogen_55bf8444c1c26b91fd490c7216f4d0f8aa0a1f1a.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_e4d9a2396ceccdadab24602f30e9070901a76dc7.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_14fea611f3c253aebf726af3e5fdb7e63e18e13a.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_587fc33d02b1932235b8d152e57559060211d591.hip 
+fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_deterministic.hip -> fmha_ck_autogen_680e81c3700f130df142c9a37a368944ca548721.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_dropout_wg16.hip -> fmha_ck_autogen_62048a8ae1c0096f3372b0114c15edbe813425fd.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_9b4dcde1ae3446b825dea739d4295c1d1ec5c4be.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask.hip -> fmha_ck_autogen_ede81dbc4cb208ef6e684c76ba1eb451d37fe10c.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask_deterministic.hip -> fmha_ck_autogen_59901147b7188212b8d8feea15831a11425fe4b3.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask_dropout_wg16.hip -> fmha_ck_autogen_c9ad71883a19b522486706d3705700c012a6fc19.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_ad82071cc074fd30437f6158b5eb2c6df1f8c587.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv.hip -> fmha_ck_autogen_bd3daa5f99b4522d932334924347353ce2854821.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi.hip -> fmha_ck_autogen_b72a804bb3c99830653d41ac0bd49943c801b89a.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_deterministic.hip -> fmha_ck_autogen_572e68bd619e118292768f0925ccf92cbfa68415.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_ee1a43f2210a8d1e5623411c95c33424cee5e747.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_a93a03b33305b33055273711ab31a5b8d8298d5d.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_c3cfaf0d53869c373f6d0ec821b008dbb819141a.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_b2af5f5b5ee3ae964824a3e9c7bbeb5bb39c557c.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_56964a17f902257aca9d08c736516a2c67d9a0e9.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_f9824fb32933b27501ae8a7f43f460a2dda6a814.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_deterministic.hip -> fmha_ck_autogen_4118e3ab290263ed2576feaf22a1944bf2ddcb7a.hip 
+fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16.hip -> fmha_ck_autogen_ce5ad502dd40353312d561e9f40aa478c16ef5b1.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_6d07bf9c05e41dcf2416e05dab4bdde17158db76.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask.hip -> fmha_ck_autogen_b1c5d55d47d6038e9162d32ac968ff58c0942938.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_deterministic.hip -> fmha_ck_autogen_30c8e4d5c761fda50e010da779e8e4730051d403.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16.hip -> fmha_ck_autogen_193699a5daa14ca2def07489e0b563149bc403f8.hip +fmha_bwd_d32_fp16_batch_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_c0342686e4efd26413c6719782ed13603479c4e0.hip +fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk.hip -> fmha_ck_autogen_fb79e1f9231692d736dbada062ed6821f34927bf.hip +fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi.hip -> fmha_ck_autogen_5f3c3bed2b584ea2031debf9f953f5f8f7012171.hip +fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_deterministic.hip -> fmha_ck_autogen_319df310195191895005b30151da8c1afab6c82f.hip +fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16.hip -> fmha_ck_autogen_6af23d1460abfe875e71f7911697c42fef0f41c5.hip +fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_cde0582e1aef74f9209de638b553ec0671476258.hip +fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask.hip -> fmha_ck_autogen_5052b2318dbb78b1a82ef03666a35a623f44481b.hip +fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_deterministic.hip -> fmha_ck_autogen_2543da478310245e19e6c6a0d9ed7ad99540b3bc.hip +fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_20f7ea0aabd069362ba4bbd66623cea5b6e1a6bd.hip +fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_da74887afedbd67928fe4d596709f9ff92530611.hip +fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_deterministic.hip -> fmha_ck_autogen_155c3549d067464d186a99b8205317cc000d4898.hip +fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16.hip -> fmha_ck_autogen_089a347aef8a920e3b59d5ffe71fc5bfe002609c.hip 
+fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16_deterministic.hip -> fmha_ck_autogen_b0dd965d5d9080ed5c6a04b7eea9890f3a264f20.hip +fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask.hip -> fmha_ck_autogen_2db33b5442d2e0948762b1f2147a321a9d6907be.hip +fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_deterministic.hip -> fmha_ck_autogen_1cc459e57bfed5ec7f40ea4a4dd9f72f3ad7a709.hip +fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16.hip -> fmha_ck_autogen_02ff94e3c787a7b06ffc90c25777fa74f225e32c.hip +fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_86309c036d96367939ccc3e8922595ac35a3e179.hip +fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv.hip -> fmha_ck_autogen_8b92990df507e82f96eeb7aa3ec00c01437566fb.hip +fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi.hip -> fmha_ck_autogen_26835ba70606c769e56d19dbfe74061361aa855e.hip +fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_deterministic.hip -> fmha_ck_autogen_dc1a7f9b1afeba6690fdc0d0d1755ea89c805573.hip +fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_815918206483d2ae04a45aa67d69dfb986587214.hip +fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_e1c1a31a1d8556cbe0b6ea76faacc78855108539.hip +fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_54b6e18b10d529eb6b32d7c19c59eaefc7184376.hip +fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_a622fa57764ec746e02f6d4bd4846b48c722b807.hip +fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_3b5b3c218e4a7b459e54080e24c5b730221eac02.hip +fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_85fdde4b25e2fc8cbdd46c2850c19eac8d9af8f6.hip +fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_deterministic.hip -> fmha_ck_autogen_a4b7f10440331a8a88ff93ba253217c2832bcf9e.hip +fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16.hip -> fmha_ck_autogen_4e9a933b916285d9580a76df543cfafc88a536cb.hip +fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_fe8b8c3525fe86a20a2d6c69585f3e36c16caabd.hip 
+fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask.hip -> fmha_ck_autogen_7d12e9cb599d24631c082e3cf65d2c58b6d4d44f.hip +fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_deterministic.hip -> fmha_ck_autogen_8e812705ae3e452810794fa7caceef2ef6066dfb.hip +fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16.hip -> fmha_ck_autogen_7cdc419d4248dfdeeab1f0980aec35fa134e52e0.hip +fmha_bwd_d32_fp16_group_b32x128x32x32x32x32x64x32x32_r1x4x1_r4x1x1_r2x2x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_a046e888e3836b0bd3c49fec8e1872e880798f0c.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp.hip -> fmha_ck_autogen_8278845045d68027dcf3bf867ecde2fb12ec51d3.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi.hip -> fmha_ck_autogen_18a4d71b31c451a50df7996e3db864bc3c3882ed.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_deterministic.hip -> fmha_ck_autogen_5c36fc744dfb0d985c9113175e76c7ec1c935054.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_dropout_wg16.hip -> fmha_ck_autogen_b779cc0b0380e1e6a2b51fc6216fdd72215b882b.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_19af6a7f9e5020e8d0f0ca0f6258001f6ce592c1.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask.hip -> fmha_ck_autogen_459ea3713aef9b916e1b38a882a45012930924d3.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask_deterministic.hip -> fmha_ck_autogen_977137b371df841993c8d0584be7d83aca6add78.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_7497eca4d1a18306b406b367653622a8d64095bf.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_9bf235679af1ca03a6e601b4cf6cd0416d1c9091.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_deterministic.hip -> fmha_ck_autogen_7177f939ac3dae8749cbf4232dcf04d2cf63b48f.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_dropout_wg16.hip -> fmha_ck_autogen_1847fef2c06ea581b0ab31af1cb0556c572696ad.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_dropout_wg16_deterministic.hip -> fmha_ck_autogen_b5bccc85f74f54a2ceb17fe3040b04fe306c53f9.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask.hip -> fmha_ck_autogen_f7aa9c39b06e55bf4bc9f9a2a0fb075c9d4e69ce.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask_deterministic.hip -> 
fmha_ck_autogen_a78fecb9725ceb4bcf2aa037d43bc43efeb1c3fd.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask_dropout_wg16.hip -> fmha_ck_autogen_f93bf815b520a9d9e17b43bf9d7fb870751b6225.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_b24f91dec2029b25d0d96962528410df55a468ed.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps.hip -> fmha_ck_autogen_00a2adbe938d458d51ca5fc4020667a215b672a4.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi.hip -> fmha_ck_autogen_036887daf6cc092e7422a17882488e59cecfb643.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_deterministic.hip -> fmha_ck_autogen_1a96f0ac76f117e66eba97cb990c2350561ec2ab.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_dropout_wg16.hip -> fmha_ck_autogen_0c3b2ec99fa7b09c7f78dcc3142a661d686044ac.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_4e760de14b71a41882ec4a2c7362565af36d1a5d.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask.hip -> fmha_ck_autogen_94aa519eb57e5797125728492d9330f5c0f0670a.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask_deterministic.hip -> fmha_ck_autogen_6bad2ed9f91bc1efd89ea66cd5c775fa140cf931.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_9b73c92a13757877f34bd8a13c6fb29b60999020.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_1dc6e599144a093203fd7f92ac6d3c2cd7180d49.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_deterministic.hip -> fmha_ck_autogen_7e6129eead18d13a4a6cb9550384fddabc7a2a16.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_dropout_wg16.hip -> fmha_ck_autogen_04f39b453505f68a5091f68b1c3de48369d1e7ea.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_dropout_wg16_deterministic.hip -> fmha_ck_autogen_c5b440ca9a5196ee1e72c878c87d96934e9273c8.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask.hip -> fmha_ck_autogen_cb4576e8ea5d59d7663f3760009a00a19e1b0667.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask_deterministic.hip -> fmha_ck_autogen_44690e48f30657b0fcfa26fb3b9af3ef76e792e3.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask_dropout_wg16.hip -> fmha_ck_autogen_b872f9e6ebe330cc1818ea82b53acec79a2f672c.hip 
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_0fcb7492feb79e27e0bda73e57ef7dab410e2bb6.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk.hip -> fmha_ck_autogen_7a242e5953f44316b6a4f6587ec26283ed6cbcae.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi.hip -> fmha_ck_autogen_2184fba2eec5899bb40d49d4508196e6be1ec1b1.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_deterministic.hip -> fmha_ck_autogen_06b74acd9abfbd1c4ec2f4c718eeb92a0bca7bab.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_dropout_wg16.hip -> fmha_ck_autogen_ce5c161b725becf059fb4439c668edd454ac77d1.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_addb6a14043c5a4df0f5042b3770b40c4e90795c.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask.hip -> fmha_ck_autogen_7ddd621da88c57798db1e689b93b692b6519ff96.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask_deterministic.hip -> fmha_ck_autogen_b0544a38dfdf4d81dc95894387845f48435e299a.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_11ff174ff2175e9ec22ac3a0fa59dd7713b79643.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_a3f9c236d24b30bc9c3fad90cfd6eb00da835de2.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_deterministic.hip -> fmha_ck_autogen_515128c6978449b33ce0c35b02a9e9aaad65ef7a.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_dropout_wg16.hip -> fmha_ck_autogen_0b3153af7bcdba33115a0d31f121fd76be2ffbcc.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_dropout_wg16_deterministic.hip -> fmha_ck_autogen_d3a2edf232786d458e2125f8dfeda8847f842afa.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask.hip -> fmha_ck_autogen_a7f7553a7d2f6d42fe695cdc64423c85223af440.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask_deterministic.hip -> fmha_ck_autogen_a9b50c6ebb27986ce5b378d8c39315eb9cb91dea.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask_dropout_wg16.hip -> fmha_ck_autogen_2f55a23a0f24ff7062a4c286944f25d2db3e20a4.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_0be8cf70c6be969ecfca675782c860b5b75ac089.hip 
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk.hip -> fmha_ck_autogen_1e9130607a2d24cb0662a47e9cf12c6602143838.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi.hip -> fmha_ck_autogen_cee81ab2e2678816c7b516d2d4c50e8cb5874c68.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_deterministic.hip -> fmha_ck_autogen_c5fef330a975002ed15670e8e7b26a10376d3cb7.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16.hip -> fmha_ck_autogen_0c9bd38b8f9009d932ec49204fdea39a52885246.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_82c932e6eaaf44861c794539d9caf8b50192fc44.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask.hip -> fmha_ck_autogen_4568af1b2f104664fd05d21ad789aed39ecfa42b.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_deterministic.hip -> fmha_ck_autogen_d9c23b7f8fcc4e4f4c81f5f00cfd345b98df2e0f.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_de7eb562a7eff31d589e12945d80233aac202ae2.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_a92b43d374642df991edef1f6036dc898bf77cf8.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_deterministic.hip -> fmha_ck_autogen_014c209d5cfc6b965bfd78c64bf132c0154e32be.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16.hip -> fmha_ck_autogen_1687ddf65ce4ed2997583e20fee9f201e86633b3.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16_deterministic.hip -> fmha_ck_autogen_fc5841a729099340d608e31023acbeaeade3e886.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask.hip -> fmha_ck_autogen_9cc3ef3d3b36f52089548e9dce522b0448e2c26a.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_deterministic.hip -> fmha_ck_autogen_8efb5fc2ace6839eac741c5e6616665845f43566.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16.hip -> fmha_ck_autogen_ef5421703cbfa63a58ec02701e245d479a1fbfc1.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_b50e6df20a2426abd3d2ff2262a37c009196024c.hip +fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv.hip -> fmha_ck_autogen_a094599fb5caf5e7aba728cd4713a8d0c6368a46.hip 
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi.hip -> fmha_ck_autogen_21e235e31d6955393ac8e825bd69ead70687b7c8.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_deterministic.hip -> fmha_ck_autogen_289071756e7d0582eb61ce6483fa3c988d2e10b5.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_1899e28aff2fb168cdc3af7132dd7fd09c2e1ced.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_2e30f50071113dc4ab59468d568ac9deb06b0342.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask.hip -> fmha_ck_autogen_bdab172627718278a71a93e3737ef08ad9259a4f.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_77200e875e0ef160b311c7de450c137772312d0d.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_cb1b91c16e0255fe7a0a85638b98d94634e143a9.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_877e33463b3bf1853c6d2d2009af8d27bf88abbe.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_deterministic.hip -> fmha_ck_autogen_92e53359c69bbe4d7405d45261a8a62008eb7d06.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_dropout_wg16.hip -> fmha_ck_autogen_7764814a0de7702f0b7b5ce9dede6440603f4853.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_05dfe927fd64a564c5fad537fb7c41ee9c94c2c0.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask.hip -> fmha_ck_autogen_78f7e2a2c08cd87702793f91b6935cbe4c22be55.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask_deterministic.hip -> fmha_ck_autogen_d4605b2ad3e3753c5f255678abc1690b949c5abc.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask_dropout_wg16.hip -> fmha_ck_autogen_037c6c80fcec3eb8b0bef50ad6af6d27bf5447f5.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_fd9cd1305633b62b68fb8474ce021f639f8492e7.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv.hip -> fmha_ck_autogen_d2f4b869ff23874b6bde0aab68c419108b7e69f4.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi.hip -> fmha_ck_autogen_6ff58a5186d69efd6062f3717bd315394ea6592b.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_deterministic.hip -> fmha_ck_autogen_8021fa266c77e6b5bd1af2a9c22c686e5a6eac78.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_de5359f0fba3da9dfed06ddbea8fe2a33a9cf40c.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_fe72cdd69944d2d765478d4aed13066a02b76f6d.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask.hip -> fmha_ck_autogen_6a7b6781ffff9a42beebb4d73f0d15461ddd4479.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_28f7634d29bef11fd466b452a46b0612f38c949b.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_66f651d3415562206c1049b172261fddba01ea6c.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_32438250078ba2a47345ec4955dafb4e4de78a25.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_deterministic.hip -> fmha_ck_autogen_5ea53f7c6370845fa94aa9b395c52fd1900b62de.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_dropout_wg16.hip -> fmha_ck_autogen_d50ac8e8a03f8e7ec2c6e993dd39f09f465dab57.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_e2b629c37cf94134693ce455b8c88b72a39df7fe.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask.hip -> fmha_ck_autogen_157b89d8d625b8244b5cceaa4d3e5fc5a09c8989.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask_deterministic.hip -> fmha_ck_autogen_5789f267d34c9961ced63ad07ffea2c6d2911415.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask_dropout_wg16.hip -> fmha_ck_autogen_38010c9bf7341588f071f889b7a0b4dcc4e7a14c.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_3d55cb42b0096a8ae338ce100f86e378aa1a04c9.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv.hip -> fmha_ck_autogen_11e7df31541c3aa919e9825ad7dc4432f9a03c0c.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi.hip -> fmha_ck_autogen_d7145383e39dec0e346b5094401acf85ef3c2075.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_deterministic.hip -> fmha_ck_autogen_04c363e11d202c6d2f4bb753661c5a2043edc0ad.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_8fb33fc20f2e85e915f1b1529ae87981dfcaf86d.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_97851d5ecbf02f8af623988b1a39c0b91e51533a.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask.hip -> fmha_ck_autogen_9163ae070075f26926a86d39e15c27e6edb1f1cf.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_9ab73ea77ec20ea3bfaf995dacf93a6960ecdca0.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_21828c7d3f5574690f12f841c27f025206e6165b.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_dc08afbff5def8bcb4e823657ce01f57c9dc77c9.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_deterministic.hip -> fmha_ck_autogen_875b08ca602fe48840c72cd61798acb98540fcd6.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_dropout_wg16.hip -> fmha_ck_autogen_216806a4598c885e517e664fc8280c59ec3cbf11.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_f90410c26d7649e21e2ae5e32e7af89d84d2ea70.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask.hip -> fmha_ck_autogen_a3339150d8bf9d073827738527f6cbe15b854607.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask_deterministic.hip -> fmha_ck_autogen_7a0ab620e6d62259a559e329460e46e6e3f7c3f9.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask_dropout_wg16.hip -> fmha_ck_autogen_7a2e032f6500fbc5468183415b6dd1d3e43f0bee.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_71a2d046629a4b65c90d0e18d061c4984062f844.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv.hip -> fmha_ck_autogen_04ffca078cfab8bc6c4ccd1cc8994a1bb4a88ea7.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi.hip -> fmha_ck_autogen_836a308c2d2afd6e0dfbfda61984b631c4ccffc6.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_deterministic.hip -> fmha_ck_autogen_62ba7a5a0f3a714eb5f9f2af20f7bfbc82a30350.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_e9b04e6d5527ba0b8089ba8bdd264e2d5759338b.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_ce5b5932f6df9a194ceb0d69220fba9596528eec.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_cb3d5273945c5d40cc05c2660af2df1fb7a15f3c.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_5ace1c9b00f160a17355d4583d49c47887ac33c8.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_054fda16133a0d25077967b05425f9128e1fe1a5.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_7adf69b51f0a8cc9ae7e250e60df38758230fe4f.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_deterministic.hip -> fmha_ck_autogen_cd757a8bbeabd16a44d149ab188430f6d79ddcaf.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16.hip -> fmha_ck_autogen_a5fa94bb32a80e81886b711ebfcf2df5f5405866.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_a02f152e9184af0b3d77082d8bdf519dbbfceb2d.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask.hip -> fmha_ck_autogen_cf73e1fc0015094861ca0c1c81bacdbe0c5b8f37.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_deterministic.hip -> fmha_ck_autogen_a9df9ac4ee78e5f4d5bd0567e58a7090907c61e1.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16.hip -> fmha_ck_autogen_92121fd448b4640a17e1a7fe73bb7b58714c0afb.hip
+fmha_bwd_d64_bf16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_2c9756060ac0e73dbcfc58a9222a78f0283cd029.hip
+fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk.hip -> fmha_ck_autogen_354121d3bad1d448bd413718fa096f54faa12e95.hip
+fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi.hip -> fmha_ck_autogen_d4c9f975891087e6eed6393629b41155deafc509.hip
+fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_deterministic.hip -> fmha_ck_autogen_9bcc791049e3ff9ebc1a9085d2d20efcc2f99b71.hip
+fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16.hip -> fmha_ck_autogen_79d0b8053ddf99a4d4447656d733c2da026b3a7c.hip
+fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_6f8788c537cbf6833c58a6ca15c0a36de33c9fbd.hip
+fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask.hip -> fmha_ck_autogen_5fa19223cf296d7fd10e15e2571e63c84a80fbb1.hip
+fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_deterministic.hip -> fmha_ck_autogen_4dde56efe17f4fd36a11cc959320a5e43f1dc232.hip
+fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_fabdc143c29d5ca50ab1e96a814bda6d05b0d5d2.hip
+fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_c977735a36c325706bd19a12df66ed0839b032b1.hip
+fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_deterministic.hip -> fmha_ck_autogen_7872c45ba170f2782c4b5b75cfc78ac79a4cf157.hip
+fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16.hip -> fmha_ck_autogen_7c4710e8f4e27fae4ae079f1667c3a1879cb6da8.hip
+fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16_deterministic.hip -> fmha_ck_autogen_09e50367b62bb09071e28b44235a7c112645a706.hip
+fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask.hip -> fmha_ck_autogen_1be43f8b629e7039f57b95866d5777273377470d.hip
+fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_deterministic.hip -> fmha_ck_autogen_d0de618ff3ea9f67b90f2227fb7fcc74ea34183d.hip
+fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16.hip -> fmha_ck_autogen_865eb90b1a2d64acc0f6fbe1d807c501fd4be3cd.hip
+fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_bb35c86443cc9ea38c06ebc0656306483c95ef67.hip
+fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv.hip -> fmha_ck_autogen_ec171210efd217c07d357fcf42e5372ad7e9abab.hip
+fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi.hip -> fmha_ck_autogen_bd80a1774d8b7d8bee4e8663392b97cda11dcbf5.hip
+fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_deterministic.hip -> fmha_ck_autogen_b19f05f6848403480ba41d37cdbf44ccca1b1f8d.hip
+fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_e639a1e84faa98477b05df71d363b9ff0f9b2760.hip
+fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_a388a284f45f711d82a6ed87036d87cef1872eb1.hip
+fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_26ea90eb5a527434c1740933a1d2dd863eccf14c.hip
+fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_e16edb824cecf459a8ec51b8dc74b1e06369aceb.hip
+fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_843e7888cba5f463d19fcb71aaaab25dc3d2c09d.hip
+fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_4f6243c6850c0a2d2b7bf1476e12f95f187257b6.hip
+fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_deterministic.hip -> fmha_ck_autogen_89617bdea526d12d6a33ed42b9b0018c0b173722.hip
+fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16.hip -> fmha_ck_autogen_2b4050988e5790a28dbe10b4c20e14f10f6cf85c.hip
+fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_44cc95831c347212021c0bab7b43acd7daabce42.hip
+fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask.hip -> fmha_ck_autogen_ece60111633db08f765b3c7cd5cd768cbd030255.hip
+fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_deterministic.hip -> fmha_ck_autogen_79a7dce707954e765d97cb22e57d9bd6168860d9.hip
+fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16.hip -> fmha_ck_autogen_761bde840c0c8149b24a8f6f264e963c4e9e8ceb.hip
+fmha_bwd_d64_bf16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_609616f72bf16a060fa50091ac139ddc06bf9d88.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp.hip -> fmha_ck_autogen_ca1992a2634cd6674076611be54197c715ad8271.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi.hip -> fmha_ck_autogen_2f0247e301a7b076b6ec8a778c3b47e330638963.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_deterministic.hip -> fmha_ck_autogen_55b14cf2998a61611d1de2594e926fcdc378999c.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_dropout_wg16.hip -> fmha_ck_autogen_21411df58165946bf02942b597d94de7dd856987.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_b3063d06723ac70c5f8802ab49c5c35e1debf56e.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask.hip -> fmha_ck_autogen_4052ca6a3ec02f6559e4bbf1edde42ad2d127c26.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask_deterministic.hip -> fmha_ck_autogen_d41cd6b60a97e7071518cbd1a63abb8b910df024.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_e75d492ac3a6ab75648056bcf26250a4aa929cfd.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_474fe2d739eca8c93fdcb2c105d4154cee6ca1c1.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_deterministic.hip -> fmha_ck_autogen_2c0bda0feaade2b554d648d72f219ac9c389bf09.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_dropout_wg16.hip -> fmha_ck_autogen_2122c973581930ab7a4ebc90b3bf1cdaa229a87f.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_dropout_wg16_deterministic.hip -> fmha_ck_autogen_a20c91b2f11bb7e5058ca7935b0bda4f5558a9dc.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask.hip -> fmha_ck_autogen_9990e6ad243a48b84304b5cad0c663c0802aedfd.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask_deterministic.hip -> fmha_ck_autogen_7264e378e1ea1d4dd97f6949d66f3492883b663e.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask_dropout_wg16.hip -> fmha_ck_autogen_7878e2a4d3b96a552e03d1ffc33debfd50c9f7f1.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_fc1eb85a00017efdc610e4259d2abe935b85304f.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps.hip -> fmha_ck_autogen_cbf3e4d4d4837a0cb33b78c4f2767b1d93da0850.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi.hip -> fmha_ck_autogen_5f8925f929a5b26f3544ca31938aa75b3c59d34d.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_deterministic.hip -> fmha_ck_autogen_8004763f674dfb3f14b66dfdeb2a046e413ce2cb.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_dropout_wg16.hip -> fmha_ck_autogen_0878b9aa31429d23a93cd953cc6a2fc5f43d0d3a.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_b5ba2e73df35f6e0f7317303823fde92a42b1a35.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask.hip -> fmha_ck_autogen_d34fcb56caa8f80404789fba0ffac447483a4d84.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask_deterministic.hip -> fmha_ck_autogen_cb1a0ce432c27f4cfa51731c3ef181bf60c8a727.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_efb9e7d9af47cdf79f15f674f8976c05f08b0ce8.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_357f7e626135cc9176a295f3d1f336a7c3852688.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_deterministic.hip -> fmha_ck_autogen_22c142d869ef940ca876c93033ad53b576ed34f2.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_dropout_wg16.hip -> fmha_ck_autogen_1621507cf219fe608715d4e5bb6e5764022e2d61.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_dropout_wg16_deterministic.hip -> fmha_ck_autogen_a25e2aed617e1ff31f93ae7e054313ee0dceee97.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask.hip -> fmha_ck_autogen_7ec038393ec329a894aee9bbac078a40f57a4684.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask_deterministic.hip -> fmha_ck_autogen_15dc02ea7e0908cf0bd48034f5a49debfaa36219.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask_dropout_wg16.hip -> fmha_ck_autogen_758b211174da0f398b2a093e7389905b4f9c4060.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_ps_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_548b347672451e8391388a400d016803f4c4cf8d.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk.hip -> fmha_ck_autogen_ae7899b1ef159ecbf01f27014601eb79b31b49b3.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi.hip -> fmha_ck_autogen_b04f14f829eff73afaa57a875f74ebd1e6860979.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_deterministic.hip -> fmha_ck_autogen_2ad492377add5c8f6d0d2dbf9ee9e4338bbd9f1f.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_dropout_wg16.hip -> fmha_ck_autogen_7f6ccdb3c2d595fffd05bc5e6417b157276547fb.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_69cbe8eca7e3510f5caa7f13419cfbefbf031754.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask.hip -> fmha_ck_autogen_8bd7b8c63a51c8639b3cf27ad09d41ae47c480d3.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask_deterministic.hip -> fmha_ck_autogen_f21596e8c608a795ff971aea8e199db9e72b65d7.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_1da23de9604b5d98fe02529075bad995954c12ca.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_49d4c005d723cdab9fbc307933c1257d114b539e.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_deterministic.hip -> fmha_ck_autogen_e2c9f955f227430c6224ebc347649386be7f01eb.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_dropout_wg16.hip -> fmha_ck_autogen_290c484c2a366258941ee0051e139ea716a9de2f.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_dropout_wg16_deterministic.hip -> fmha_ck_autogen_84cca7528c7d1bf49ba79625733ff0ae7522c096.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask.hip -> fmha_ck_autogen_f3d0166931e4406873d8f552a5d5b61fde2391a3.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask_deterministic.hip -> fmha_ck_autogen_8046f566fa7188c92568b277354e8b06ad382544.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask_dropout_wg16.hip -> fmha_ck_autogen_12d60c8abecb3bc9b84b0ea7851628ab17d8b0b3.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_psk_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_f50fa4ea674a590d0a817367ad9915a5fce20c51.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk.hip -> fmha_ck_autogen_0836d5dfc0f939ab9a4064b403339373caf35b56.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi.hip -> fmha_ck_autogen_de6683d175affaa5ff261ab8503f64172d8eba8b.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_deterministic.hip -> fmha_ck_autogen_beb9afccc15de7dfcb2e7d898abc0d61201de73e.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16.hip -> fmha_ck_autogen_e6e0ec1db1ea308e226f675e68e29b839e41b252.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_7c3d8ef4da515960bf40eb1feb04d21950ad5ae5.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask.hip -> fmha_ck_autogen_fcbe827108d252b2f5847fa8e132c9c3e56a90a0.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_deterministic.hip -> fmha_ck_autogen_7993fc08ac5c6ce7a2eceb1227f4e3718dc4cf5f.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_06ae52ef937cc27c544e32025ea0dadb7fad982d.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_876a418fbe6183d0392b7a7d9986d067e323e2b9.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_deterministic.hip -> fmha_ck_autogen_b03ab68e33844f97aa58d463e00037bc11c50da0.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16.hip -> fmha_ck_autogen_8c7970957024de050748d3e31cef434f582d968b.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16_deterministic.hip -> fmha_ck_autogen_add29e3e9828911a117dccaa5650e77805730d14.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask.hip -> fmha_ck_autogen_0e007c36231ccdae12f102eacca1f74b0711b9c6.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_deterministic.hip -> fmha_ck_autogen_765940baaaa2ae6ade43ef4c94a220eaa63702b0.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16.hip -> fmha_ck_autogen_c7af2bbfac25de2853be344b9f636226c1c0112d.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_b2f91e937b427ecc932c0cb0c90b2c2378db0be6.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv.hip -> fmha_ck_autogen_8da8285bd6182355e3164cdc5a983375cdf0a61d.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi.hip -> fmha_ck_autogen_a3ff8445ba691807caadd9f26e7eb90851875280.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_deterministic.hip -> fmha_ck_autogen_9c4fc7cda4b560040cec93f63021b529aa1ee3fd.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_4018b1fcee808b6cccd131418b6ae9e8bf900d8f.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_88d52c5f70abb525b9c8aa8fc1cb3997c33ed67c.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask.hip -> fmha_ck_autogen_99e2f290b962f1617b0a9d4fd6d55c43e4439d6f.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_25938733446b6c0dcd159719f08d04a9aa467967.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_76f884e9ca116ee47b446efe9fc770c178a858d5.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_42e2326066c91452335eac05f25a6311376bd9e5.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_deterministic.hip -> fmha_ck_autogen_24643917fc970c043d1c80d8d4b17ec92deeb8a1.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_dropout_wg16.hip -> fmha_ck_autogen_d937609afa8e21a761dad6b01ff3f26346e450fc.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_59beb9cb4e161f9dcff79080149076488d436301.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask.hip -> fmha_ck_autogen_fd3558b4c7a667dbc365c4c2ceda646975408f51.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask_deterministic.hip -> fmha_ck_autogen_dda8d021381083bc48b7fb1840729254dd8e5137.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask_dropout_wg16.hip -> fmha_ck_autogen_ed37ba962e0288e2840eb0925d016b5a7e3b3164.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_5467aea26852aa9a9e3dae76b906005ddf6fbae1.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv.hip -> fmha_ck_autogen_76be322fc072ca19baa82707e260c6eba936ae19.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi.hip -> fmha_ck_autogen_c921a4790f982d48bcaf950123c699647afb739b.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_deterministic.hip -> fmha_ck_autogen_76674fc182dfa6329c73a354aa3adf458429444a.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_54402a22ceee3b665a3f24edb98b8398c35c6f5a.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_ada016be2bd0e377fbe01fa7adb9bbb8febce100.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask.hip -> fmha_ck_autogen_6db86621d626722434f2ae9b7b8ab435a8dd8827.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_162b0dfbe3f615b1d164290799b2457437a0044b.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_628b28f65f19e7d1b22fb3b85b7cf3d09cd54ebc.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_031b12f9fd94e01aaff2c0da4f35f346822087e4.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_deterministic.hip -> fmha_ck_autogen_b9a742ceeb6736a2c8f9439d0b05e10d3e0c5c6f.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_dropout_wg16.hip -> fmha_ck_autogen_afccf699f593c828e11efc053b144044e45b32d6.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_fba36678d5047ded97ee7a7ba9feb9569afdb6ea.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask.hip -> fmha_ck_autogen_14baaaf1e90a075ab802c6e7d97c4b1605c8bd72.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask_deterministic.hip -> fmha_ck_autogen_0237c76137df14fb808ade8bd6837045f2aaa5c9.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask_dropout_wg16.hip -> fmha_ck_autogen_c2a2856bf9a81544a30d535a13554e3a8107c476.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_c2940fd05efd52bdf8a3f9aa4b78bde9b5809b34.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv.hip -> fmha_ck_autogen_d049a1b8f4c1c6d37973ce38593efda1de8ce0cd.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi.hip -> fmha_ck_autogen_f4b87f983a5e84582efa1663f84da76cf60b5f6f.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_deterministic.hip -> fmha_ck_autogen_4db2e63cfebcf84043f79be0321708cd159c62b9.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_f25b87c435bc5d7d85d738f3fdf68947d79f5a77.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_540bd57333c6839ccf5cf2e928edb996bc60c371.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask.hip -> fmha_ck_autogen_9583148fd684a7e6a312127e023798278415bd27.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_bf9cdf86a7944cd690b0fcbbaec235863acd10bb.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_2da2b905c4ce32234c2af62328adae6b1f9217a8.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_c4015f0d0a7a5173810f6f17c00065e03fc61a89.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_deterministic.hip -> fmha_ck_autogen_d773df9ccfc1ace90fe3afb5c00976deabedf6f8.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_dropout_wg16.hip -> fmha_ck_autogen_d137b7b6e04e1caf43a62bd6788a75361cfa98f6.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_adaef10ff2c5d89530310bdf1d53a194f06a94ef.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask.hip -> fmha_ck_autogen_1be746990a2032f0363ad9f9112cc994983f4706.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask_deterministic.hip -> fmha_ck_autogen_55bd9c4f1b7a0621c67f3e964d946ce22fb2fc80.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask_dropout_wg16.hip -> fmha_ck_autogen_4dc87b7d385e7b092e4706c464217b004fd8a6a4.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_pskddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_91695dea4171747fb3cc6d910459f800608d07c1.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv.hip -> fmha_ck_autogen_c137c03bf161b2ec6a9a046fa49d7bbf80ae47b8.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi.hip -> fmha_ck_autogen_83080406598df6bd3102db70a554e496e29db96a.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_deterministic.hip -> fmha_ck_autogen_03a71615a088e972c998f9c7cb44566c268c5124.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_6214f820b39a8ba81e547a78ed19a909ac13221c.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_3e2557f206fd81d82a3b9d59113105040beb891f.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_461737a13e24009bf1a5a4b780175043a9f2e33e.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_83f6a1837a65df12b7c55d25ca28cc939c2a6328.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_c59a22c6efd8bb8815887325aa0b739e260cc754.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_6049c01db99fce654e9351e711b113cf7424550a.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_deterministic.hip -> fmha_ck_autogen_c9f28230817c9d9805c41dfcd4e834fe302e1df1.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16.hip -> fmha_ck_autogen_7728d5bec7941c9b6d5632bee8d67ed92b9c03ec.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_28f1ef32c4384ec26f3dc5e3af6a74fc8cebae92.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask.hip -> fmha_ck_autogen_594929c433b049a8cf949ff476309a8faf5c25fb.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_deterministic.hip -> fmha_ck_autogen_8441910c34830ad2459fb85c2c14af02da718fdc.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16.hip -> fmha_ck_autogen_230861e81e5acc523fa680534eed757b7b4a4e1d.hip
+fmha_bwd_d64_fp16_batch_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_c112c01d201c366bdd7acccf2e1b18b00f671153.hip
+fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk.hip -> fmha_ck_autogen_6b638314efcc4f16aa4a6e58e6caf2fda1711519.hip
+fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi.hip -> fmha_ck_autogen_c8f6461673882d636772ae4d26e78eabcb568f31.hip
+fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_deterministic.hip -> fmha_ck_autogen_f93bc23b8a4f1e0fc5c5756c4e1c835bf59dea09.hip
+fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16.hip -> fmha_ck_autogen_4356b3a2ff49f72b91a6b9c215df285f2798ad47.hip
+fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_e1cc934ba7baab1a2eb062df1e4ee5066e9ffbc3.hip
+fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask.hip -> fmha_ck_autogen_137fa6780d9e6bde10aec10a875c039fdbbc652e.hip
+fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_deterministic.hip -> fmha_ck_autogen_06ba94794a14f0f0022af6f5f3c16e1e16959d4c.hip
+fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_4b1eaca3c37a82d19f8dc91f06764170069ca3af.hip
+fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_91c916e14198f6d18dc89915e379b01070434e91.hip
+fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_deterministic.hip -> fmha_ck_autogen_8e816fcad5e9ecfca94a6491eb2274bcc41e558b.hip
+fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16.hip -> fmha_ck_autogen_5fc66c5b53f83bf1e023e81e9d51f0285b3ae731.hip
+fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_dropout_wg16_deterministic.hip -> fmha_ck_autogen_2d9c659ba43bb907fd4e3e36a50958288bafd1a3.hip
+fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask.hip -> fmha_ck_autogen_07ff04fcc273e469737512893ea3fb5876ac131d.hip
+fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_deterministic.hip -> fmha_ck_autogen_22632f996eb63fbe4bc5748c5897b775087446a0.hip
+fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16.hip -> fmha_ck_autogen_f5f1797f6b672a55476348571ce17645c8a62869.hip
+fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_iglp_pssk_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_eee408cf9456ff977aa7d12345e9b2f1e60639f1.hip
+fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv.hip -> fmha_ck_autogen_303b7b04496e4db7c1ba2436485dc7c8a4c88448.hip
+fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi.hip -> fmha_ck_autogen_fcb0b08e29b2e1bf181fceceb9dc416e54f52b00.hip
+fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_deterministic.hip -> fmha_ck_autogen_d06ba4c996570ddab77b6ff1e2a0101b638543eb.hip
+fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16.hip -> fmha_ck_autogen_fc5ebf0f2200f37ccc0849e0c3745f6e2f00111d.hip
+fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_dropout_wg16_deterministic.hip -> fmha_ck_autogen_2caba3ab83239e474412fcf89fe0fbef97e51bf1.hip
+fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_dc184767d723f4995791848cdc68bd948408204f.hip
+fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_deterministic.hip -> fmha_ck_autogen_c53e295b68e807774ed31bb914e4bc59312a77d7.hip
+fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16.hip -> fmha_ck_autogen_db0d0cf55d90b3f3c9eecada1db93c420f34b1ae.hip
+fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_alibi_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_d1c25cfc437d8bd803860e39a45b2f3b9fa48393.hip
+fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_deterministic.hip -> fmha_ck_autogen_01ca79005067e20e4eed5a72ff9187cde702cd1c.hip
+fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16.hip -> fmha_ck_autogen_a5e5cae764142683b70d3344cf07dd1edb7d69e2.hip
+fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_dropout_wg16_deterministic.hip -> fmha_ck_autogen_ca920c3239bb5796b1ab2fc75177eb3b820aa784.hip
+fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask.hip -> fmha_ck_autogen_806f9ab9baf631df1d3a8d801e4cf93a102526cf.hip
+fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_deterministic.hip -> fmha_ck_autogen_4b30f472f00bec9da0564ddc40e07112b5f9a117.hip
+fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16.hip -> fmha_ck_autogen_dc039d422a57c159ea4dbcc867d766ff1b356a07.hip
+fmha_bwd_d64_fp16_group_b32x128x64x32x64x32x32x64x64_r1x4x1_r4x1x1_r1x4x1_w16x16x32_w16x16x16_o1_kr_ktr_vr_psskddv_mask_dropout_wg16_deterministic.hip -> fmha_ck_autogen_5b55946ff3c15a44b9c741e9f6bbbcb5bd4c8577.hip
+fmha_bwd_dot_do_o_d128_bf16_batch_o2.hip -> fmha_ck_autogen_658552954505a2092662071401e135e84956c4c0.hip
+fmha_bwd_dot_do_o_d128_bf16_batch_o2_pdv.hip -> fmha_ck_autogen_53bd60bd2afee49b30a583c32a45ae9f2076db08.hip
+fmha_bwd_dot_do_o_d128_bf16_batch_o2_ps.hip -> fmha_ck_autogen_8e675919a6c7758cbbeecb83b7ac6c62f95cdb46.hip
+fmha_bwd_dot_do_o_d128_bf16_batch_o2_psdv.hip -> fmha_ck_autogen_2d06f77a4054ca615d96636c0e2eba2a89850142.hip
+fmha_bwd_dot_do_o_d128_bf16_group_o2_ps.hip -> fmha_ck_autogen_187963e1969301abfa61d06afc97faea2bb4efb1.hip
+fmha_bwd_dot_do_o_d128_bf16_group_o2_psdv.hip -> fmha_ck_autogen_e7153f9a9b0b7c54ddf2debbe297efcffbb4fcfa.hip
+fmha_bwd_dot_do_o_d128_fp16_batch_o2.hip -> fmha_ck_autogen_3c3b7e4b8c1efe59f79a15512716fce2282a79a7.hip
+fmha_bwd_dot_do_o_d128_fp16_batch_o2_pdv.hip -> fmha_ck_autogen_19cd9f7b08cec83736605af63d9fcaf463a1aea4.hip
+fmha_bwd_dot_do_o_d128_fp16_batch_o2_ps.hip -> fmha_ck_autogen_b4588379eaa268d79fe8f8e4457b009f204a5fb7.hip
+fmha_bwd_dot_do_o_d128_fp16_batch_o2_psdv.hip -> fmha_ck_autogen_23c9b46da8774462de8c24e14b12df3ed596eb57.hip
+fmha_bwd_dot_do_o_d128_fp16_group_o2_ps.hip -> fmha_ck_autogen_5b413bdc825ae863d53dab548f2145dc0de8fd37.hip
+fmha_bwd_dot_do_o_d128_fp16_group_o2_psdv.hip -> fmha_ck_autogen_58a7ab44bbd9fbc97c7805860d5f6ac81d6ae468.hip
+fmha_bwd_dot_do_o_d256_bf16_batch_o2.hip -> fmha_ck_autogen_50f887556a3540609649744957651ca667b91774.hip
+fmha_bwd_dot_do_o_d256_bf16_batch_o2_pdv.hip -> fmha_ck_autogen_eac5952f46f4f2bf06257b00661774eeed48a323.hip
+fmha_bwd_dot_do_o_d256_bf16_batch_o2_ps.hip -> fmha_ck_autogen_efaa0cb33c71cb8ca7b83dd0e7a6c7b01f6b50a9.hip
+fmha_bwd_dot_do_o_d256_bf16_batch_o2_psdv.hip -> fmha_ck_autogen_71e5fb3544dafa9da03fd2de4bb9bd0718f6009f.hip
+fmha_bwd_dot_do_o_d256_bf16_group_o2_ps.hip -> fmha_ck_autogen_3fad30ff0739ab5dede67a96e859f8c474c245f8.hip
+fmha_bwd_dot_do_o_d256_bf16_group_o2_psdv.hip -> fmha_ck_autogen_4bef4d120e71bfcfe61d67aa44d24ceb907c2b9e.hip
+fmha_bwd_dot_do_o_d256_fp16_batch_o2.hip -> fmha_ck_autogen_7d0f767c17385eb7d756cbe8ed444d7cef72dea5.hip
+fmha_bwd_dot_do_o_d256_fp16_batch_o2_pdv.hip -> fmha_ck_autogen_4b68e4d00295b294320b94bc777d7d34609127e0.hip
+fmha_bwd_dot_do_o_d256_fp16_batch_o2_ps.hip -> fmha_ck_autogen_33746071156e9ad46f403a539dc237e0a44122a7.hip
+fmha_bwd_dot_do_o_d256_fp16_batch_o2_psdv.hip -> fmha_ck_autogen_3d45624dc6e33c477c73a155500b015b6c010de8.hip
+fmha_bwd_dot_do_o_d256_fp16_group_o2_ps.hip -> fmha_ck_autogen_8250f27341241086515d833aa53ae873d4ece3fa.hip
+fmha_bwd_dot_do_o_d256_fp16_group_o2_psdv.hip -> fmha_ck_autogen_8793dc3217e154b65ebba065aa10ab4dc2374ae8.hip
+fmha_bwd_dot_do_o_d32_bf16_batch_o2.hip -> fmha_ck_autogen_1a11dd5ebb989503a1c182684e7f247e2f8cd9c2.hip
+fmha_bwd_dot_do_o_d32_bf16_batch_o2_pdv.hip -> fmha_ck_autogen_e16075c3a5fcfe63ba12e854bb1fed6873f014ab.hip
+fmha_bwd_dot_do_o_d32_bf16_batch_o2_ps.hip -> fmha_ck_autogen_937801fbb43fb6797f0425f08d13926b74d87c4a.hip
+fmha_bwd_dot_do_o_d32_bf16_batch_o2_psdv.hip -> fmha_ck_autogen_fecffa403b3631b1957e1a9a06f18fdb3b4eee5f.hip
+fmha_bwd_dot_do_o_d32_bf16_group_o2_ps.hip -> fmha_ck_autogen_5ba578c0e7abf1127dd0370f06d7278656c93ab9.hip
+fmha_bwd_dot_do_o_d32_bf16_group_o2_psdv.hip -> fmha_ck_autogen_345a939a2491166dc520e9a2b9de7e43671e0c2b.hip
+fmha_bwd_dot_do_o_d32_fp16_batch_o2.hip -> fmha_ck_autogen_7393267865f1c2b0aa1a09a586f54cec98eea4ae.hip
+fmha_bwd_dot_do_o_d32_fp16_batch_o2_pdv.hip -> fmha_ck_autogen_93b885d6869400b0dc2ef1b2c2636ddfd21cde31.hip
+fmha_bwd_dot_do_o_d32_fp16_batch_o2_ps.hip -> fmha_ck_autogen_38f8a89468cf9c8606cf12a930db062a83cd0ea0.hip
+fmha_bwd_dot_do_o_d32_fp16_batch_o2_psdv.hip -> fmha_ck_autogen_f974b12e83e214c30995a25631d37df1478927af.hip
+fmha_bwd_dot_do_o_d32_fp16_group_o2_ps.hip -> fmha_ck_autogen_2bb6da1095bd8669c0e48b5cd808cf0dcefa2674.hip
+fmha_bwd_dot_do_o_d32_fp16_group_o2_psdv.hip -> fmha_ck_autogen_0e0a2370f2a320484d8f9f21e3197425c2dbe9ad.hip
+fmha_bwd_dot_do_o_d64_bf16_batch_o2.hip -> fmha_ck_autogen_a9f00f270680de81df7737e848e0408cb070e68b.hip
+fmha_bwd_dot_do_o_d64_bf16_batch_o2_pdv.hip -> fmha_ck_autogen_61220f6dca850a5b5ccf1f619a267c40c37efeca.hip
+fmha_bwd_dot_do_o_d64_bf16_batch_o2_ps.hip -> fmha_ck_autogen_b192c55f002d8540d5f965cc4df0c2e33f4b9ff9.hip
+fmha_bwd_dot_do_o_d64_bf16_batch_o2_psdv.hip -> fmha_ck_autogen_295a523f815eb822d66162d4feb75fe0bc50b648.hip
+fmha_bwd_dot_do_o_d64_bf16_group_o2_ps.hip -> fmha_ck_autogen_292b4f995d622826af5d1f2bffa7ba68467c841a.hip
+fmha_bwd_dot_do_o_d64_bf16_group_o2_psdv.hip -> fmha_ck_autogen_5e840be0741afa4d41fd4789c8300223fdc63ddc.hip
+fmha_bwd_dot_do_o_d64_fp16_batch_o2.hip -> fmha_ck_autogen_0e1dbc9c433ce8ec33ace9e62550261d613db582.hip
+fmha_bwd_dot_do_o_d64_fp16_batch_o2_pdv.hip -> fmha_ck_autogen_6eebd0c2fbfc85f938b10535855c388971129a28.hip
+fmha_bwd_dot_do_o_d64_fp16_batch_o2_ps.hip -> fmha_ck_autogen_0bc7910aac798f0555e9e505ad7f177c9fbbd92c.hip
+fmha_bwd_dot_do_o_d64_fp16_batch_o2_psdv.hip -> fmha_ck_autogen_18b92b4e249195ac3e0c74d246585a4c9e0992fd.hip
+fmha_bwd_dot_do_o_d64_fp16_group_o2_ps.hip -> fmha_ck_autogen_278639d44a4a8372a627a7c31e9527c8faa26f97.hip
+fmha_bwd_dot_do_o_d64_fp16_group_o2_psdv.hip -> fmha_ck_autogen_8e938d0e3ad30db201880642e57758285b2ec4cb.hip
+fmha_fwd_api.hip -> fmha_ck_autogen_1ca3f45d0be2d1119cccd0af042a3e8adeda2ed7.hip
+fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv.hip -> fmha_ck_autogen_f727911254904ce4341e4ff5f8bafc430b8cfbbf.hip
+fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi.hip -> fmha_ck_autogen_54208a6e8c5263e38f9ffcb062564ab61d2785ff.hip
+fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi_dropout.hip -> fmha_ck_autogen_1d3ef3d5ded0dfe2a0bafb52ea8f841658db35fd.hip
+fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi_lse.hip -> fmha_ck_autogen_f15c41ddb04ec7f80235bb3db19198dd6b699713.hip
+fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi_lse_dropout.hip -> fmha_ck_autogen_a5c4dc0d70c547dbbfb661e879ba7f9adfafc2ea.hip
+fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi_mask.hip -> fmha_ck_autogen_d7290cc4c3036c9205e689cbcc60e7d16b97a7d6.hip
+fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi_mask_dropout.hip -> fmha_ck_autogen_0b2647b5982405a48e8c8888552a4b89386ccdd9.hip
+fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi_mask_lse.hip -> fmha_ck_autogen_eb278488b2cca114adca5e4614d86f92447f937a.hip
+fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi_mask_lse_dropout.hip -> fmha_ck_autogen_29fe68ba10b3480dddc9866c51ca8b5efe962cc3.hip
+fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_dropout.hip -> fmha_ck_autogen_92992be6252f2afdc368bd4baec4b8a55ae0abf8.hip
+fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_lse.hip -> fmha_ck_autogen_501dcf3213efd214cc2ce8c9ba0027f991d241b4.hip
+fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_lse_dropout.hip -> fmha_ck_autogen_aa6d13b09f85ee62bb5018608812181fb43afc86.hip
+fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_mask.hip -> fmha_ck_autogen_d0f63cafbeb445408c884727b473667fb479675e.hip
+fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_mask_dropout.hip -> fmha_ck_autogen_7596c14b8fee751d03f42ca48ea4f66e87fc2e2f.hip
+fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_mask_lse.hip -> fmha_ck_autogen_c2b719893a4d8a1e71857966d399f06c0a41749c.hip
+fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_mask_lse_dropout.hip -> fmha_ck_autogen_071751b1012b90f7b57f8591cd06ae1fd27d9cd3.hip
+fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv.hip -> fmha_ck_autogen_d00f65bc99ca08eba66564d34f72f2769bff9491.hip
+fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi.hip -> fmha_ck_autogen_2273457ac3be01cc1595a015a5f598f8290c77e4.hip
+fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_dropout.hip -> fmha_ck_autogen_63c411351ec59bdbed2590c599f9eddf7807b371.hip
+fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_lse.hip -> fmha_ck_autogen_042a156e9eb935555ab14a84461959b466c2fb5b.hip
+fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_lse_dropout.hip -> fmha_ck_autogen_eab6cdc59bf216f7045f0cf5f221bb91ec415cd2.hip
+fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_d703eea8075cacec4d41fee7dc4734f593ee79e8.hip
+fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_dropout.hip -> fmha_ck_autogen_2f32f2d658f1f69840fbad511ce8a3851c859d52.hip
+fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_lse.hip -> fmha_ck_autogen_e6973d75297bd2c3432a7c88e8a9ee1c9ae693bf.hip
+fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_lse_dropout.hip -> fmha_ck_autogen_854c8003a508ed3f8cbe6967c4ae2635a491c721.hip
+fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_dropout.hip -> fmha_ck_autogen_ceb9544e2a0caae2c9e3dd8bbd2c509e8dca1379.hip
+fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_lse.hip -> fmha_ck_autogen_e83c604d1b8260958becd1c7c209745ff9151715.hip
+fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_lse_dropout.hip -> fmha_ck_autogen_3b26eafe76cca8e74e819220b6de1f4279d48e43.hip
+fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask.hip -> fmha_ck_autogen_d5e82799f4452e148c3e02acd6526cf30757eb52.hip
+fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask_dropout.hip -> fmha_ck_autogen_5435b4651a90e331fcdcf224282457e3dc038a30.hip
+fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask_lse.hip -> fmha_ck_autogen_1573e3d855d28c54af612ab950b081302891d56d.hip
+fmha_fwd_d128_bf16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask_lse_dropout.hip -> fmha_ck_autogen_4e47f8fa40332c6ed12d9971e0b539049a871c34.hip
+fmha_fwd_d128_bf16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv.hip -> fmha_ck_autogen_b285e2f1970b78e18002464eeda63798229bbc3a.hip
+fmha_fwd_d128_bf16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi.hip -> fmha_ck_autogen_75f21e38ad01fade35b1db40adabd75eb602410c.hip +fmha_fwd_d128_bf16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_dropout.hip -> fmha_ck_autogen_81f6c575c3fa2ccc7e65022f1ba65c8cfc16541e.hip +fmha_fwd_d128_bf16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_lse.hip -> fmha_ck_autogen_45b9871c220c0065d74bffeed4021d0304a9625c.hip +fmha_fwd_d128_bf16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_lse_dropout.hip -> fmha_ck_autogen_f028af9e5e3c25800dde938e991aaab4fc1d64aa.hip +fmha_fwd_d128_bf16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_7fa76fc1b066a15b08dc6c24a7cf33a58b4cb6cb.hip +fmha_fwd_d128_bf16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_dropout.hip -> fmha_ck_autogen_157768cd725813f8111d265cfdfea7f42034e5e9.hip +fmha_fwd_d128_bf16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_lse.hip -> fmha_ck_autogen_541874a7633e5713720b9d084b6d1c6715a51a17.hip +fmha_fwd_d128_bf16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_lse_dropout.hip -> fmha_ck_autogen_6f88527a2cdb5adf51407f4661a254bb32d7de23.hip +fmha_fwd_d128_bf16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_dropout.hip -> fmha_ck_autogen_a55b47aafc4340e69e300ac61a7601a5c14513b7.hip +fmha_fwd_d128_bf16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_lse.hip -> fmha_ck_autogen_20d5c3c86398f6ce55abc90db3e362dbf9f457f2.hip +fmha_fwd_d128_bf16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_lse_dropout.hip -> fmha_ck_autogen_8cf1007430da272174d3476d042f398627e83512.hip +fmha_fwd_d128_bf16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask.hip -> fmha_ck_autogen_e7d37e7ee96c392fa24c02a9143438a3a7d05741.hip +fmha_fwd_d128_bf16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask_dropout.hip -> fmha_ck_autogen_dc91797c1474a368e9cb056b50b4629d7736c3cb.hip +fmha_fwd_d128_bf16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask_lse.hip -> fmha_ck_autogen_3cce3baac1e3ca03af0c3f4ee4d0158ad1031e9f.hip +fmha_fwd_d128_bf16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask_lse_dropout.hip -> fmha_ck_autogen_9d6759d8855c4c6289f1f241a1628cf0406c1b64.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv.hip -> fmha_ck_autogen_b38a1d3cffae01332a3a9d9472ff1b2c443e82af.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi.hip -> fmha_ck_autogen_2cf351fc2c2da4a8e1760a3affc9a5947c6b3bda.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi_dropout.hip -> fmha_ck_autogen_bafbef3f13d429ec3e9f4672218998d5669d79f2.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi_lse.hip -> fmha_ck_autogen_3f34433b784d1e405ade3378918641372a30bf6b.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi_lse_dropout.hip -> fmha_ck_autogen_5fb062527121e627871b3f1b2a94b96c42e51205.hip 
+fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi_mask.hip -> fmha_ck_autogen_5732094f5917e9164ee0f973ac6ec47245a69101.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi_mask_dropout.hip -> fmha_ck_autogen_688aaa193f332ed13e017e78ec07a7c80e45f6c5.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi_mask_lse.hip -> fmha_ck_autogen_1cbf88db44aa5f884438288a325270d29c7a04b6.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi_mask_lse_dropout.hip -> fmha_ck_autogen_2660282ad39ef034fecbdb74acedfb48620b7dfd.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_dropout.hip -> fmha_ck_autogen_a59423c095db052603d77073d409534bceef425f.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_lse.hip -> fmha_ck_autogen_3fcc6893456a559c7d22714116022fc69b372266.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_lse_dropout.hip -> fmha_ck_autogen_c7568e11e44ce70924d27e683190422cfae5c31d.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_mask.hip -> fmha_ck_autogen_f79def2b4edf6d18f6ef1d6b141f9e0435441f6a.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_mask_dropout.hip -> fmha_ck_autogen_32652a27e8605cef59c8341813b68e7513be23c5.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_mask_lse.hip -> fmha_ck_autogen_b20e314642cf565e4f32bceffdb5c0e653ab627b.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_mask_lse_dropout.hip -> fmha_ck_autogen_a74b0e7dd816ad08eec5a1bba6e227afee9813ec.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv.hip -> fmha_ck_autogen_a968df29f5ae1463706b7981b3bde55918e1aa65.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi.hip -> fmha_ck_autogen_b5248f443a12d96815c04409a00102923c717023.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_dropout.hip -> fmha_ck_autogen_291a8bdf9d63b112e7fe5fa7e8835a6789cb8ecf.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_lse.hip -> fmha_ck_autogen_6d5aad18f59e47a3fa3278c7ef1a6372830c33d5.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_lse_dropout.hip -> fmha_ck_autogen_c063318cb851ccaa923be12d34c84d839bc64bb8.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_a5a7833f4597bb03a3e845d5580d677e97421040.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_dropout.hip -> fmha_ck_autogen_2d9a04b7f41dd6f0db017157a44790f35c626e2d.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_lse.hip -> fmha_ck_autogen_98f5efcd500ce6b9ffc14bc9877e0ba457539925.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_lse_dropout.hip -> fmha_ck_autogen_135ea67de101135ed5fe04f5cab1ec1d7b3714bb.hip 
+fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_dropout.hip -> fmha_ck_autogen_951343832a5bfd060c8d12da0d8a090f070a717d.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_lse.hip -> fmha_ck_autogen_f24d42e820adc1a26a428d59df7ffdd7f8580176.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_lse_dropout.hip -> fmha_ck_autogen_4dbdd9c3f496a27bde68cf86374999ff2dd53505.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask.hip -> fmha_ck_autogen_3be7cea6df8e6dd56194e1172f28943667f1c4ef.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask_dropout.hip -> fmha_ck_autogen_483eaea4096c8f5bee16a64860432f0634a253d8.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask_lse.hip -> fmha_ck_autogen_103186dbad604763008e0204a1ea90baecef8877.hip +fmha_fwd_d128_fp16_batch_shb_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask_lse_dropout.hip -> fmha_ck_autogen_367e58867c46d96c9bbaa96eaaa9f93595c9e099.hip +fmha_fwd_d128_fp16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv.hip -> fmha_ck_autogen_311104394c8bef8d4ecff35c1409221e723a5a8a.hip +fmha_fwd_d128_fp16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi.hip -> fmha_ck_autogen_976cf509d9c2bf86ba6ee5ded544fa8e6717f590.hip +fmha_fwd_d128_fp16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_dropout.hip -> fmha_ck_autogen_24410fd9a4150c33186a2a365d06d8f6ea621c20.hip +fmha_fwd_d128_fp16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_lse.hip -> fmha_ck_autogen_b493c99888d82cd2852bfb101f99a2e6a27665b8.hip +fmha_fwd_d128_fp16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_lse_dropout.hip -> fmha_ck_autogen_1fda1c96568eab89a8f6498f8bb23c1223cdc7b0.hip +fmha_fwd_d128_fp16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_053981d9e7af2ebc0f91e61ac5e25cbe68c95bd8.hip +fmha_fwd_d128_fp16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_dropout.hip -> fmha_ck_autogen_3110540b50e95e99a5cccebe47d9d3a83093c2fb.hip +fmha_fwd_d128_fp16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_lse.hip -> fmha_ck_autogen_1fcdcb750f382fc7828a9886585f50efbe5be735.hip +fmha_fwd_d128_fp16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_lse_dropout.hip -> fmha_ck_autogen_c3d0eaf9399c863d672e8c08d123739bab837d4b.hip +fmha_fwd_d128_fp16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_dropout.hip -> fmha_ck_autogen_9d69d441f48f9ea346dd8e00376a9a708da3ad87.hip +fmha_fwd_d128_fp16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_lse.hip -> fmha_ck_autogen_3992d5df4ba2e999caf6889a852db4e1ba078e65.hip +fmha_fwd_d128_fp16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_lse_dropout.hip -> fmha_ck_autogen_f30316cfe49323638f71ba688dd8ff9b2266b335.hip +fmha_fwd_d128_fp16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask.hip -> fmha_ck_autogen_797750ac0b18b48f56ceb4640256e9bd3a36621a.hip 
+fmha_fwd_d128_fp16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask_dropout.hip -> fmha_ck_autogen_942439e4f5644a3a4630481bc7d98834b29b6e1c.hip +fmha_fwd_d128_fp16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask_lse.hip -> fmha_ck_autogen_fac99c3c82b77946f6844699d2333cd532a78a26.hip +fmha_fwd_d128_fp16_group_hbs_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask_lse_dropout.hip -> fmha_ck_autogen_98f9a4f4d85f292b78123599a2e1798f12aa545b.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr.hip -> fmha_ck_autogen_ea591185b1c5f521023e250a26f742984255b241.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_alibi.hip -> fmha_ck_autogen_48300e0aeabe337785d4c7b41796ce65df6cc42a.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_alibi_dropout.hip -> fmha_ck_autogen_e514c6b4bc75d95a150104a17972abae77cb47ed.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_alibi_lse.hip -> fmha_ck_autogen_a64b4cf3f6706e4b4e0af4402e2263b9a1585f9b.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_alibi_lse_dropout.hip -> fmha_ck_autogen_e389d0e4442cd8304081892ddc75043e68a6398c.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_alibi_mask.hip -> fmha_ck_autogen_ab43f4a56c166dad0113f51b337a083f4df7cdb6.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_alibi_mask_dropout.hip -> fmha_ck_autogen_d4645b713821371161a9925dec8a3d6c157ba1aa.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_alibi_mask_lse.hip -> fmha_ck_autogen_0b90a0186d8b8004e3f19886c7992c8e04d0e066.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_alibi_mask_lse_dropout.hip -> fmha_ck_autogen_d34d6cdcd81a456125ab5e0875466c6334d8e5c8.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_dropout.hip -> fmha_ck_autogen_d0b09e8513646fbb2a007544a63ec9e2b04dc4c2.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_lse.hip -> fmha_ck_autogen_ca3d98ff43fbb80ceb82fc22ab039bee898969b0.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_lse_dropout.hip -> fmha_ck_autogen_7ea9c37d92e344f3cc58cd4d1d00f19167e3623e.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_mask.hip -> fmha_ck_autogen_db85839ee8d464c5a81b8dad9839f5e0f4b467a8.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_mask_dropout.hip -> fmha_ck_autogen_32527660fa7aeb9a951a9f2fc3c53989bd141c48.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_mask_lse.hip -> fmha_ck_autogen_528db08068589c6e4c096054d26a2e5be63285b6.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_mask_lse_dropout.hip -> fmha_ck_autogen_d600779c17b7b21c18e1308e6d765fe02a7945d3.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv.hip -> fmha_ck_autogen_445e28a8a51cd435130ded2abc9fc606e522c713.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi.hip -> fmha_ck_autogen_8a980749c6b2a18c80426dd189e5506334343ca4.hip 
+fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi_dropout.hip -> fmha_ck_autogen_da822ea727fb3543e445e4000f7e6ebb946d6a3b.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi_lse.hip -> fmha_ck_autogen_f525b59df454ccf53da6cb201e0aa8d09f52a2ad.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi_lse_dropout.hip -> fmha_ck_autogen_0a2b116fd5065109aae46ee547e4f49ad0e9d6e1.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_366662dccf2f650bcd8123c49006c759cd4c0ef6.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi_mask_dropout.hip -> fmha_ck_autogen_816c48e129a0235cb3a19124ddb28cce286fb368.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi_mask_lse.hip -> fmha_ck_autogen_356f83cb96d0313abcdb24955edd4264df72aed7.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi_mask_lse_dropout.hip -> fmha_ck_autogen_0e661b5f30566d1f159f060c264849c7ae4772f1.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_dropout.hip -> fmha_ck_autogen_61a9e92183ba87924e73ff0b5e25bd12d6038e69.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_lse.hip -> fmha_ck_autogen_e502730dea6987e2c038446c448aa08bdcc23113.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_lse_dropout.hip -> fmha_ck_autogen_f851da732f397624717160f89271514bc334b59b.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_mask.hip -> fmha_ck_autogen_fd345632e0cae0d549ba79626a08b1885711deb6.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_mask_dropout.hip -> fmha_ck_autogen_937c48d0b7096ad6c8bc445f13f2c8c1934695ab.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_mask_lse.hip -> fmha_ck_autogen_a2482a64659c838f3da55f56e3cbbee1dbfe6722.hip +fmha_fwd_d256_bf16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_mask_lse_dropout.hip -> fmha_ck_autogen_f34fdb8294257d951dcc9c4fa7ecf1192568b91b.hip +fmha_fwd_d256_bf16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv.hip -> fmha_ck_autogen_0aafb881e34a3794970a1282af740b3f19c138b1.hip +fmha_fwd_d256_bf16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi.hip -> fmha_ck_autogen_c250ea59ab6e1ee39cce15cbd3f181047cdee31a.hip +fmha_fwd_d256_bf16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi_dropout.hip -> fmha_ck_autogen_4ce671f5defd76ca08614a7a1f184c36c0f1e2ab.hip +fmha_fwd_d256_bf16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi_lse.hip -> fmha_ck_autogen_b9627f9c8d0088df0364a64643f2b5dcd951f2bb.hip +fmha_fwd_d256_bf16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi_lse_dropout.hip -> fmha_ck_autogen_a6461d72fb6ba50e81de3f661528c96dcfdc3f3c.hip +fmha_fwd_d256_bf16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_aa82d20635e592edbf00439294835f6f39ad54a3.hip +fmha_fwd_d256_bf16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi_mask_dropout.hip -> 
fmha_ck_autogen_146eb8c40e3146e06936f3141b2c4d92a578ddec.hip +fmha_fwd_d256_bf16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi_mask_lse.hip -> fmha_ck_autogen_c28de8f96c8315877031a2d56261e95fee6aef44.hip +fmha_fwd_d256_bf16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi_mask_lse_dropout.hip -> fmha_ck_autogen_39422621a00ff79b2f5ec0dafb957c77693537b3.hip +fmha_fwd_d256_bf16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_dropout.hip -> fmha_ck_autogen_a0a556c9358ddd6db719458c81d2d6d822a895da.hip +fmha_fwd_d256_bf16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_lse.hip -> fmha_ck_autogen_c2fcced07cc194a8050bc7b2f791453b3f5b2064.hip +fmha_fwd_d256_bf16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_lse_dropout.hip -> fmha_ck_autogen_210ef512b7862837f54acbc3b21e135a192647a3.hip +fmha_fwd_d256_bf16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_mask.hip -> fmha_ck_autogen_bef3bd014a918feddadc98eed92a7734f9bcd890.hip +fmha_fwd_d256_bf16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_mask_dropout.hip -> fmha_ck_autogen_ae1ab1f4bbe86bb9bbc22e4774648076c321136f.hip +fmha_fwd_d256_bf16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_mask_lse.hip -> fmha_ck_autogen_52a8a323414448c50571a334f29bc0a38919b61d.hip +fmha_fwd_d256_bf16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_mask_lse_dropout.hip -> fmha_ck_autogen_204a573ce6b7d2f90aede543939315561cc43177.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr.hip -> fmha_ck_autogen_d8901a63986cc28ef24cab012b32114851a8c1ec.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_alibi.hip -> fmha_ck_autogen_12d5c8a4988efe60ef7943ecd73e18a28a736583.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_alibi_dropout.hip -> fmha_ck_autogen_e5b65fc519ea7cfcd19f7eddbc3acad6842ff558.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_alibi_lse.hip -> fmha_ck_autogen_743176ecb1f0bc800c870861585edf56f88d7739.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_alibi_lse_dropout.hip -> fmha_ck_autogen_6b0ef67ce0f178aa2863c4909f5bdd7f766c9b2f.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_alibi_mask.hip -> fmha_ck_autogen_ef40f0acf1885096efb840ec5600ec421c4db331.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_alibi_mask_dropout.hip -> fmha_ck_autogen_523e5bf45ec5008aa3aba4773e68a78e122b2fe7.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_alibi_mask_lse.hip -> fmha_ck_autogen_55cda610c235987e13232e828f8d86fa88030560.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_alibi_mask_lse_dropout.hip -> fmha_ck_autogen_566b4782793c6526bfce7362efbf6bf069928b2b.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_dropout.hip -> fmha_ck_autogen_cfec97bdfb6fa95e057eaf5a8138853e1c0884f2.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_lse.hip -> fmha_ck_autogen_6905ba47078abd7a5b6a51eb93b26095517e7f70.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_lse_dropout.hip -> fmha_ck_autogen_8840e8899b4e632714632450bcef001c6070f955.hip 
+fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_mask.hip -> fmha_ck_autogen_d867098db97b3f26e71a151c63b74260bfab21f8.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_mask_dropout.hip -> fmha_ck_autogen_bc238fd2095b26a167b41cdec8280182330b7b25.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_mask_lse.hip -> fmha_ck_autogen_b737410b404a51043fc3bd503c0b107c297e4c9f.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_mask_lse_dropout.hip -> fmha_ck_autogen_b4a5715b550f67b8870ba66e1e6282a26cc1dbf3.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv.hip -> fmha_ck_autogen_12207f4b6e7fac27d6c16493a5373f448a2aaae8.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi.hip -> fmha_ck_autogen_7d5667b27f15a06d4040354fba3601d48bb9c045.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi_dropout.hip -> fmha_ck_autogen_2695783ae8f0034692efd6563f789ef03fd0f4f3.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi_lse.hip -> fmha_ck_autogen_60801d21c14796c08377349ec86a6c800af497b7.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi_lse_dropout.hip -> fmha_ck_autogen_159ee1f1b44d1a8fbaead65d8449413bb616d15e.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_9f0517550c7a23882b95de451e8099ea2186b4ce.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi_mask_dropout.hip -> fmha_ck_autogen_80f51f0e178c33e6196df1d2e47bd38bf5391cc8.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi_mask_lse.hip -> fmha_ck_autogen_489e7be0f85656d012a6451b65f6c1d2613b187d.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi_mask_lse_dropout.hip -> fmha_ck_autogen_e7de729aa50c10d8101ef504138c3769e3286753.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_dropout.hip -> fmha_ck_autogen_25b3225da1e1842f83592971a1f62a0fe30aa9d3.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_lse.hip -> fmha_ck_autogen_ce4714e4f33340859c106a3129993e22652262e2.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_lse_dropout.hip -> fmha_ck_autogen_bc4e0f0496a34d2fb43c80ce0162ad4183f29064.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_mask.hip -> fmha_ck_autogen_a9d2be18e2d53a5144f97dfdebb225fcb6d611d3.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_mask_dropout.hip -> fmha_ck_autogen_4ab5d6e8fbfd92e9f7e47bda5cfbb0d4162a6319.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_mask_lse.hip -> fmha_ck_autogen_ac1ccde31b47e0e56ee0daab6403fed7895208c7.hip +fmha_fwd_d256_fp16_batch_shb_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_mask_lse_dropout.hip -> fmha_ck_autogen_5cd03e29403ad53d6d52e5e81182ea6ff5aff2be.hip +fmha_fwd_d256_fp16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv.hip -> fmha_ck_autogen_2005aca3520b171bb82d10ad70fef44f28c19776.hip 
+fmha_fwd_d256_fp16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi.hip -> fmha_ck_autogen_c402e84359b2037a29efd1d6ce7213ba7605ab25.hip +fmha_fwd_d256_fp16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi_dropout.hip -> fmha_ck_autogen_95061acc6650fc7b79fa1fe5b2b1e083555eec2c.hip +fmha_fwd_d256_fp16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi_lse.hip -> fmha_ck_autogen_1fd9fa7c2e13d0bad5fddb2b5a316bbc09d397ea.hip +fmha_fwd_d256_fp16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi_lse_dropout.hip -> fmha_ck_autogen_dd9494d9ac35eba6794a4f9120d2db9932596ef8.hip +fmha_fwd_d256_fp16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_82d7f61e6313930f063758b61102e7a43b118beb.hip +fmha_fwd_d256_fp16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi_mask_dropout.hip -> fmha_ck_autogen_2b50073f6dfeb7ea77d5dce288a1d2f08f8f6362.hip +fmha_fwd_d256_fp16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi_mask_lse.hip -> fmha_ck_autogen_5fa7fafd4227918e0c7f0c6ca3b2bd673cd07279.hip +fmha_fwd_d256_fp16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_alibi_mask_lse_dropout.hip -> fmha_ck_autogen_522a2a9435103ed405dc1500d31652f1d431a49d.hip +fmha_fwd_d256_fp16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_dropout.hip -> fmha_ck_autogen_4b7393d55600c9892558248f4131fc06a6cf3309.hip +fmha_fwd_d256_fp16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_lse.hip -> fmha_ck_autogen_d66c30148a6fa816937f2f095802264d3dfa0273.hip +fmha_fwd_d256_fp16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_lse_dropout.hip -> fmha_ck_autogen_8f7166d4bb0c1c9b9999ba16a1adbf09ebfdb6f1.hip +fmha_fwd_d256_fp16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_mask.hip -> fmha_ck_autogen_80cf0997573f4bcfbaaf75e40f519580a7495a17.hip +fmha_fwd_d256_fp16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_mask_dropout.hip -> fmha_ck_autogen_48d7d145f96aa8958a9208d0c8887742a8c834fd.hip +fmha_fwd_d256_fp16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_mask_lse.hip -> fmha_ck_autogen_bb111b7acc269f8d5e70915d3efde4c425aa5f5c.hip +fmha_fwd_d256_fp16_group_hbs_b128x128x32x256x32x256_r4x1x1_r4x1x1_w32x32x16_qr_vr_psskddv_mask_lse_dropout.hip -> fmha_ck_autogen_48435e5dd23e49e19dd313f9891ffec800ce74c2.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv.hip -> fmha_ck_autogen_e2762543d3380185e304f84749a70db1b8d3dd8c.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_alibi.hip -> fmha_ck_autogen_5093976cb7b32a8bd28ce92fc13af00a3e21f737.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_alibi_dropout.hip -> fmha_ck_autogen_efc6a7b25710f0626c3af534111b161e1459d2e1.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_alibi_lse.hip -> fmha_ck_autogen_a8a744edfa3a19d1493611df5bd0d4d59b707d43.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_alibi_lse_dropout.hip -> fmha_ck_autogen_e95e3908479965856843317c8b0c42a6961dfd23.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_alibi_mask.hip -> 
fmha_ck_autogen_2b5317b6cde327a842170ebff20c2b03d81379ff.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_alibi_mask_dropout.hip -> fmha_ck_autogen_99ae680eed89ea93a3a94586bd5a68dbc5439f37.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_alibi_mask_lse.hip -> fmha_ck_autogen_1edaf9d4270d2ac61c299320e06ba73f44730364.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_alibi_mask_lse_dropout.hip -> fmha_ck_autogen_0a4e76d89b175e1d9fd2e9fb908d5fce1ebb945d.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_dropout.hip -> fmha_ck_autogen_fba47fa8d9b5375bc408af68b67345ab9dba2eb8.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_lse.hip -> fmha_ck_autogen_830e3532f27b391585d5de90f3bdf97992b67651.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_lse_dropout.hip -> fmha_ck_autogen_66a020f728df204ff51e37d2ddc21afb0aad5e7b.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_mask.hip -> fmha_ck_autogen_07c3fc96d2bebe546dce6ebf46e5c7a519959599.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_mask_dropout.hip -> fmha_ck_autogen_74d5f2aef029f2103bb419cc982cae99fd1a9253.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_mask_lse.hip -> fmha_ck_autogen_58a784fb478ff5b3f1e2da9765a3a777efda92e3.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_mask_lse_dropout.hip -> fmha_ck_autogen_0766e7aa4b263a811408b285213e47176ee2bdaf.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv.hip -> fmha_ck_autogen_bbe23201fbebed25781f249e5c77c31e0e7f9ddb.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi.hip -> fmha_ck_autogen_7a890b126da2d8cfbf84f048b779cac2dd56b509.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi_dropout.hip -> fmha_ck_autogen_58679919fcd292a2a69543de0db94e2985c9d364.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi_lse.hip -> fmha_ck_autogen_84fc5e94f89d6a9287cf64662a372784511468dd.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi_lse_dropout.hip -> fmha_ck_autogen_3bed3aaf24c73073c604a3b23bb4b0358b8e3490.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_dc5ba6d73f331c76e696953606c5b347b6a46f3f.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_dropout.hip -> fmha_ck_autogen_b4f12f10d7b968e0d8e7c23f36d3a360de74a905.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_lse.hip -> fmha_ck_autogen_41b68458076e6cb129d3ec793e95b91430a0c8a1.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_lse_dropout.hip -> fmha_ck_autogen_56ffe9e21362afe9c3a407c09d5de186954931a6.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_dropout.hip -> fmha_ck_autogen_2ba934408c75da5479cc41f96b98ea7d333635ea.hip 
+fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_lse.hip -> fmha_ck_autogen_bd6aa39d0ae3c87d011610cdb5e2e317f337c454.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_lse_dropout.hip -> fmha_ck_autogen_aece14f7a220222eb4ce6783ec2b9fce6fde94b8.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_mask.hip -> fmha_ck_autogen_6e240106c771ebea461fc2a87b6da68e510aba70.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_mask_dropout.hip -> fmha_ck_autogen_232f61bf31dbb5de5d7039d5ff2338068a759b68.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_mask_lse.hip -> fmha_ck_autogen_e0e48d7edfe9513f24ad9fae68cac3aa940b17dd.hip +fmha_fwd_d32_bf16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_mask_lse_dropout.hip -> fmha_ck_autogen_bc897852a4ca992961843144f4ec4f8b86dd5e9d.hip +fmha_fwd_d32_bf16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv.hip -> fmha_ck_autogen_f1246d1013d954a9316f4432c986d3be9459c548.hip +fmha_fwd_d32_bf16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi.hip -> fmha_ck_autogen_6a4b6226b355bf35d4d07aaef1828091f03ad2ec.hip +fmha_fwd_d32_bf16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi_dropout.hip -> fmha_ck_autogen_2b49a9b0801a06dd89c7f7182d7590b515df1592.hip +fmha_fwd_d32_bf16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi_lse.hip -> fmha_ck_autogen_50e7b11019fc2299d70869253877319b03388244.hip +fmha_fwd_d32_bf16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi_lse_dropout.hip -> fmha_ck_autogen_7f9bb3486fee7b7c9e24300b8a4e4ce88a11bfc0.hip +fmha_fwd_d32_bf16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_6785dcec0197fdbb50124ab06efa627f1a2c0567.hip +fmha_fwd_d32_bf16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_dropout.hip -> fmha_ck_autogen_f87991cb7787a29d3ce4711b4ce04c5fb6a14ca9.hip +fmha_fwd_d32_bf16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_lse.hip -> fmha_ck_autogen_bc4425e30a0b17e8b31726817e8d3177b5c51934.hip +fmha_fwd_d32_bf16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_lse_dropout.hip -> fmha_ck_autogen_54940ce53998becf9bddf56df7d19894a7658168.hip +fmha_fwd_d32_bf16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_dropout.hip -> fmha_ck_autogen_ebb241b947a0adfc8e50c5d71765c14af24593ae.hip +fmha_fwd_d32_bf16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_lse.hip -> fmha_ck_autogen_3d3f3eb2f5eb1f3287879604892b1c230df85f1d.hip +fmha_fwd_d32_bf16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_lse_dropout.hip -> fmha_ck_autogen_7b9a3bf1a9b37e0bd9bae6249609e5994dc0dba1.hip +fmha_fwd_d32_bf16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_mask.hip -> fmha_ck_autogen_14221590b90c48d3cf259fb4e834ccfaf7f3209b.hip +fmha_fwd_d32_bf16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_mask_dropout.hip -> fmha_ck_autogen_445cd8fa559588f4264ce6192f2de3e3065365ea.hip +fmha_fwd_d32_bf16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_mask_lse.hip -> 
fmha_ck_autogen_7a902ed4ae3cc6558c73b730ff3949778007a230.hip +fmha_fwd_d32_bf16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_mask_lse_dropout.hip -> fmha_ck_autogen_0682150e93f547e00f13cd8984779bf49b91e50c.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv.hip -> fmha_ck_autogen_d86e4dcbe9c4cac8f7c8c5d97ce384ae0cbdbfbc.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_alibi.hip -> fmha_ck_autogen_1df893ee660d37fba7eaca452ae65b3e45a73087.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_alibi_dropout.hip -> fmha_ck_autogen_92739f4464512feee083b875e11e11eee4f5b448.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_alibi_lse.hip -> fmha_ck_autogen_65910c8b7a30acc731948ab58467fdbe4fe32f6d.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_alibi_lse_dropout.hip -> fmha_ck_autogen_df5b1c6758d4b8540158299dd0362297083084c2.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_alibi_mask.hip -> fmha_ck_autogen_ec7fc24902b1ebd8f2bf8088b0ecf6de8be8362d.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_alibi_mask_dropout.hip -> fmha_ck_autogen_9e51083e13aa4dfa8c969f8f916835a8e5e9ca39.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_alibi_mask_lse.hip -> fmha_ck_autogen_b41ea5293bc1c56efa2c4b5681d965aa6f2ce6c3.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_alibi_mask_lse_dropout.hip -> fmha_ck_autogen_813e60e8405aca3f7fbed19452ae37574ada9a77.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_dropout.hip -> fmha_ck_autogen_0ebacd06455ab20eba78b389462946716b5819f6.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_lse.hip -> fmha_ck_autogen_15b255dde1a9d915e582ee2a83de7d83190c6a24.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_lse_dropout.hip -> fmha_ck_autogen_7b2d3680c3578c7292349b58843aef7a82e0087d.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_mask.hip -> fmha_ck_autogen_1d21263e16dafe79b9fe2f998847296e575c14e7.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_mask_dropout.hip -> fmha_ck_autogen_2d23a26e0a59a8323dd97632e610d24624143fbe.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_mask_lse.hip -> fmha_ck_autogen_4fa883a36a76edb276a66c5d779294f170d6d4b7.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psddv_mask_lse_dropout.hip -> fmha_ck_autogen_9207a63fc55c411c73e4f93306c5ffed800dd249.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv.hip -> fmha_ck_autogen_0a68c2f9a3acdd787b81be455cbc7836c8bfd90c.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi.hip -> fmha_ck_autogen_4217a48a1677bd26cd48e512f1fc8830a8a551b8.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi_dropout.hip -> fmha_ck_autogen_2b0bcb241e5a1be1d35366461408d06e095a26ef.hip 
+fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi_lse.hip -> fmha_ck_autogen_f3193ea266f3718398bc5622f8bc7042c3527a42.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi_lse_dropout.hip -> fmha_ck_autogen_bb28a4e95723e3df380f98b5ac107c4df353850b.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_61204f6805d5d830aa6fca2a9b5f238ed63c3a73.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_dropout.hip -> fmha_ck_autogen_6649f19deeaea20663bee781af7edced7f7a4fc0.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_lse.hip -> fmha_ck_autogen_d3784fb4c0685d7b651f4113f3c71e050881f3a5.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_lse_dropout.hip -> fmha_ck_autogen_ed6bdf67720e938d538a867548ac3579b8238169.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_dropout.hip -> fmha_ck_autogen_971a08c2e48d805b295d979b24173a04cf58def0.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_lse.hip -> fmha_ck_autogen_c4997f79435cf64add10506acb97d0647cfbb3d4.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_lse_dropout.hip -> fmha_ck_autogen_188a70d526394e254274df95de0727850820326c.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_mask.hip -> fmha_ck_autogen_661b49505cfecbe4ec3e5c7371de3aaaa85ac9d5.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_mask_dropout.hip -> fmha_ck_autogen_d63c8c746055851217a514321cd735eaf6937263.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_mask_lse.hip -> fmha_ck_autogen_745705ae121a1a331527cedfe4d31218a428a0df.hip +fmha_fwd_d32_fp16_batch_shb_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_mask_lse_dropout.hip -> fmha_ck_autogen_6fa6478cc27e52fd9511fbff38369c921155cfb9.hip +fmha_fwd_d32_fp16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv.hip -> fmha_ck_autogen_4fa4d21931b9afcbd70b1567995d3eeb6f9308aa.hip +fmha_fwd_d32_fp16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi.hip -> fmha_ck_autogen_d43715cce8935439f90172d141050d78c7e76fb7.hip +fmha_fwd_d32_fp16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi_dropout.hip -> fmha_ck_autogen_ae1afeb6cfdf860ff08e4c2f11c922fd5bfa621a.hip +fmha_fwd_d32_fp16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi_lse.hip -> fmha_ck_autogen_f24bd5b92ce6bba640b8ec6b4e53fe35902c5572.hip +fmha_fwd_d32_fp16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi_lse_dropout.hip -> fmha_ck_autogen_481415463f0316ebe25ff2fda47c68cc54db3359.hip +fmha_fwd_d32_fp16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_db5016bff9e5dc37184d2b9417eb351c7ea1c322.hip +fmha_fwd_d32_fp16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_dropout.hip -> fmha_ck_autogen_d64b8b52f4a98801e185e2f132b2f80c29dd0c37.hip 
+fmha_fwd_d32_fp16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_lse.hip -> fmha_ck_autogen_09ecb6347009f6a5d5530a6acf90f9f40288cbcf.hip +fmha_fwd_d32_fp16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_lse_dropout.hip -> fmha_ck_autogen_50e59bd079f4d205b613056f975fd2b4e372ab10.hip +fmha_fwd_d32_fp16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_dropout.hip -> fmha_ck_autogen_fd10a3b937e9659716925e39a01d794914b08e26.hip +fmha_fwd_d32_fp16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_lse.hip -> fmha_ck_autogen_ec51d24ab5f24e003ed6751ae8ae5b327892b15a.hip +fmha_fwd_d32_fp16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_lse_dropout.hip -> fmha_ck_autogen_a5f8b7b2a891aa9f2ab49762eb31d835efdf18b6.hip +fmha_fwd_d32_fp16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_mask.hip -> fmha_ck_autogen_9a0a70932bd587759df1e5e150b25b0126d7b529.hip +fmha_fwd_d32_fp16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_mask_dropout.hip -> fmha_ck_autogen_9d3d274058bc0a3d4d35d90669587761fdfbdba1.hip +fmha_fwd_d32_fp16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_mask_lse.hip -> fmha_ck_autogen_23914c00690ac5c4f89cdbbaf00732ba66c5c0ef.hip +fmha_fwd_d32_fp16_group_hbs_b128x64x16x32x32x32_r2x1x1_r2x1x1_w32x32x16_qr_async_vr_psskddv_mask_lse_dropout.hip -> fmha_ck_autogen_0befed50a89d80c22b2c8c3d5ba67d73c3d0190e.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv.hip -> fmha_ck_autogen_88c04463f9c5ce565a9daa8c22e16de80fadd707.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi.hip -> fmha_ck_autogen_01e8aedb7b7d77f44a46b2e9b7a826f245aaf4a7.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi_dropout.hip -> fmha_ck_autogen_beae876d6da465687f162136231f15767cc7bb14.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi_lse.hip -> fmha_ck_autogen_26f90358e522d7bb7c76c3a2c6010f0f38788bb6.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi_lse_dropout.hip -> fmha_ck_autogen_d7bda8157fb27d544e049fd7d2ec735725f1bf44.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi_mask.hip -> fmha_ck_autogen_9fb389d4b5ba590baa951f17da06f0e53d2bfa55.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi_mask_dropout.hip -> fmha_ck_autogen_428ce4e14cf94b284ffa735fe03d923cc74c9fe0.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi_mask_lse.hip -> fmha_ck_autogen_900d7f81c73b35ea64095d01c5d48d9190839e0a.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi_mask_lse_dropout.hip -> fmha_ck_autogen_d2daccc4b3a0f90bff39cb4597f8b7e484613d9e.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_dropout.hip -> fmha_ck_autogen_f280e1639680ac1e5830a21f921bfe2cf364ef42.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_lse.hip -> fmha_ck_autogen_0dde401aa76cb5425563cbbdb0362748148da3ca.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_lse_dropout.hip -> 
fmha_ck_autogen_dc62a8db637d32e7dfdb2521cbdae6e1fbbd5fd1.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_mask.hip -> fmha_ck_autogen_4cd3de43cc1f7588d62a10362f59d113ee818846.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_mask_dropout.hip -> fmha_ck_autogen_224f9af5e5ca519b21b71a54acb49f50b4999c47.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_mask_lse.hip -> fmha_ck_autogen_4c8720923c3452e3aebd7b9c1b4b23f0c35d7e4f.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_mask_lse_dropout.hip -> fmha_ck_autogen_2c7aede7762a524a7a424cc4dc46e43fdedf73a2.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv.hip -> fmha_ck_autogen_a98925d99dc484da41dd55700e151cf545cf821d.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi.hip -> fmha_ck_autogen_1c65ba6dba01da9caa84ba89453b61d81376763f.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_dropout.hip -> fmha_ck_autogen_4b76e5dce9af523422782dd25d8dcf6f25edc68f.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_lse.hip -> fmha_ck_autogen_fe245e9ea974adce2b9807d33b9ba12d916eaffb.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_lse_dropout.hip -> fmha_ck_autogen_281d897ad17d7f6db2741b396e6b85a9b8f35286.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_31a968898f0bc6366313e41eddb5e3a3ed12dc98.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_dropout.hip -> fmha_ck_autogen_52688999141a72e61322140db29043ef9f7fbc3d.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_lse.hip -> fmha_ck_autogen_92b722cdabcfaa388ccc6ccceb7e42462f3bdcd1.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_lse_dropout.hip -> fmha_ck_autogen_47f3ced9b5ddb0dfee8ed5e7df8eca0bbe273047.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_dropout.hip -> fmha_ck_autogen_d2dfdb42c1b380e860aa5609302f29698dd27923.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_lse.hip -> fmha_ck_autogen_7fe409f4421193fb48a54aa5f26bd6229d23204c.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_lse_dropout.hip -> fmha_ck_autogen_b3a104733f678193068d8642d6560faa03897258.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask.hip -> fmha_ck_autogen_df66feebc9a0dcc508ce002c255154622875e524.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask_dropout.hip -> fmha_ck_autogen_8fa4c40e244b412a07933d369704bcdaa6d5e74c.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask_lse.hip -> fmha_ck_autogen_60efa9c427dc278c0d1bc31189f683cd45e4d873.hip +fmha_fwd_d64_bf16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask_lse_dropout.hip -> fmha_ck_autogen_8e50ea8dd480012cbe10be392cd26d1870e6ef9b.hip 
+fmha_fwd_d64_bf16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv.hip -> fmha_ck_autogen_e5ccd5f7ddc894b2717112cbfc766804e02b7bd1.hip +fmha_fwd_d64_bf16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi.hip -> fmha_ck_autogen_4911bdd71351610d55916d452495e599960d0a41.hip +fmha_fwd_d64_bf16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_dropout.hip -> fmha_ck_autogen_d2d08c5470a385d0160b2c1441fd1c30fff1c17c.hip +fmha_fwd_d64_bf16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_lse.hip -> fmha_ck_autogen_012c0f480917c329f4c3c6c666cf32af2d82b294.hip +fmha_fwd_d64_bf16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_lse_dropout.hip -> fmha_ck_autogen_0bb81407c8a2b3cdc5fecf655b3ad64d5d729cc9.hip +fmha_fwd_d64_bf16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_7ff65c7abd9b0d8a2df9302d6dc167637b3a72f0.hip +fmha_fwd_d64_bf16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_dropout.hip -> fmha_ck_autogen_d712f23ef88ae5d7b161d36f42d22a5ba53b6354.hip +fmha_fwd_d64_bf16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_lse.hip -> fmha_ck_autogen_5bc803342862aa30e23e5be7d84e611bc571c529.hip +fmha_fwd_d64_bf16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_lse_dropout.hip -> fmha_ck_autogen_0ace6e29e1d3060c3086c08fe27b471e375f9c75.hip +fmha_fwd_d64_bf16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_dropout.hip -> fmha_ck_autogen_54ff49018f1c12b9fa31e523ad40b9cc162ba34d.hip +fmha_fwd_d64_bf16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_lse.hip -> fmha_ck_autogen_351425a006aeeff4d69c8570cb6bf1e1427d2c21.hip +fmha_fwd_d64_bf16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_lse_dropout.hip -> fmha_ck_autogen_fcb6ef39c3db49f26f736d6c9221dd825409ec4e.hip +fmha_fwd_d64_bf16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask.hip -> fmha_ck_autogen_f98a6b193fec3203eaa75819f6b51aa45a48f212.hip +fmha_fwd_d64_bf16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask_dropout.hip -> fmha_ck_autogen_2d446754d7000673779d15d3e73039fd3c10a720.hip +fmha_fwd_d64_bf16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask_lse.hip -> fmha_ck_autogen_ca00cfdc5592b7440d72482a18781e9cf3afb05a.hip +fmha_fwd_d64_bf16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask_lse_dropout.hip -> fmha_ck_autogen_1211733062ed30b876f1d63bffa642d77e258dd6.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv.hip -> fmha_ck_autogen_9b6d08e63b9a90f2524cbfa8c5fcf8b82a1d2d36.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi.hip -> fmha_ck_autogen_e52e3053f30f780f346fa6b7a836ad2554cb85df.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi_dropout.hip -> fmha_ck_autogen_3ecf565a5a1c4a09887c67ac3b9a019dca427ac0.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi_lse.hip -> fmha_ck_autogen_52a89981a05963efcea7ba5c1e967638beeebbbb.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi_lse_dropout.hip -> 
fmha_ck_autogen_2173b7c710d418f44dc2b41bec5905024334eae5.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi_mask.hip -> fmha_ck_autogen_b1ad101ce91348266d3885afdf2996a0fdb72135.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi_mask_dropout.hip -> fmha_ck_autogen_4da9e9b7277bc90518ab92860bef2097ba96d982.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi_mask_lse.hip -> fmha_ck_autogen_7e1bdde812c332c9fc58613698568a04771b9fa8.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_alibi_mask_lse_dropout.hip -> fmha_ck_autogen_1acf2f892742b1d236d2b31a8185c6869126adad.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_dropout.hip -> fmha_ck_autogen_155bafb551768855c8c01faa63e44764ebe6c110.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_lse.hip -> fmha_ck_autogen_f053c9c32518b895daaa3521827f37af78836fb8.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_lse_dropout.hip -> fmha_ck_autogen_adf160741a4f751d2f15d6eb23d4121cdca62b55.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_mask.hip -> fmha_ck_autogen_34c2db98d8e2e690f499f41cfd5afb831b756f54.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_mask_dropout.hip -> fmha_ck_autogen_0789852b0cd3cc030c78b28f2fd5b6b0546382a4.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_mask_lse.hip -> fmha_ck_autogen_532a6ffd8a21d3e98342fd401f0247f62ca4e038.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psddv_mask_lse_dropout.hip -> fmha_ck_autogen_d0daa59f5dce6fc3965193ae37d8c82a3d1834e6.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv.hip -> fmha_ck_autogen_4a4a00bd6ea27ff20a2903d619e1361b5e27672a.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi.hip -> fmha_ck_autogen_93054acb8a9508fd0f0f486367fb62454de47c39.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_dropout.hip -> fmha_ck_autogen_b774450ebadaacf23e944aaf8ca90eada01e8a5a.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_lse.hip -> fmha_ck_autogen_2a833fc01e88bd8e256ef64ae8251dd0ed10720b.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_lse_dropout.hip -> fmha_ck_autogen_aa522b43c5e5ea69bcabb4c0fe28def2bd081a12.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_67fb736c61088b8dd92fe0371f5c98e23bf9077f.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_dropout.hip -> fmha_ck_autogen_b5c3131fb8e5a25bd4a14bc9075eb6fa01b61d02.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_lse.hip -> fmha_ck_autogen_d7fae2c18645d36a181a0bdd2d8ca7a4ac0f6d1d.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_lse_dropout.hip -> fmha_ck_autogen_c355189ade9b1a8269230232db754a3881b53168.hip 
+fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_dropout.hip -> fmha_ck_autogen_e035773419a9b3631698a3d375d829af55f7731e.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_lse.hip -> fmha_ck_autogen_d992eab7de49033f5480c5e86a69e675db0d2a19.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_lse_dropout.hip -> fmha_ck_autogen_5382a30dcf702daae19bd6705864bfe36e09502c.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask.hip -> fmha_ck_autogen_167f5328b035ed59a6f05dfee31edd704c4b07ee.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask_dropout.hip -> fmha_ck_autogen_c1b94e19d762ddc33cc4e94c6675d93cbde21e3d.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask_lse.hip -> fmha_ck_autogen_606f5e0b99814b0a82a731de36f28024bc317801.hip +fmha_fwd_d64_fp16_batch_shb_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask_lse_dropout.hip -> fmha_ck_autogen_0ad9d68fcee021437e13ffdf94d78252205f5a31.hip +fmha_fwd_d64_fp16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv.hip -> fmha_ck_autogen_85156f2c556c6ef6180608c361b7b35ede71ffea.hip +fmha_fwd_d64_fp16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi.hip -> fmha_ck_autogen_890aa875ac13957f00b30210477924697abf0c9e.hip +fmha_fwd_d64_fp16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_dropout.hip -> fmha_ck_autogen_3108502fd29d3a24b32177bcea968121ee809115.hip +fmha_fwd_d64_fp16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_lse.hip -> fmha_ck_autogen_d66b79c4ebdcfd239cecec58203606bc123bd6bb.hip +fmha_fwd_d64_fp16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_lse_dropout.hip -> fmha_ck_autogen_5efe77ca5c394a60af0313072cdd132216a52bf3.hip +fmha_fwd_d64_fp16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask.hip -> fmha_ck_autogen_772016803aa3ca6ebe785557118365f9be7c4339.hip +fmha_fwd_d64_fp16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_dropout.hip -> fmha_ck_autogen_93728d999ae43ee1b5a16e60b90cf8533c7d303f.hip +fmha_fwd_d64_fp16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_lse.hip -> fmha_ck_autogen_a1cba1509c413c870c5d784410855ee1bd737da2.hip +fmha_fwd_d64_fp16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_alibi_mask_lse_dropout.hip -> fmha_ck_autogen_c59ab718fa23f24f09a713ac28a339208a7a5802.hip +fmha_fwd_d64_fp16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_dropout.hip -> fmha_ck_autogen_afcafd07c1f56e74373ccf37db35976023456d50.hip +fmha_fwd_d64_fp16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_lse.hip -> fmha_ck_autogen_ebb9abf5b09e63cbe76390bb46ff7cbefb3141f0.hip +fmha_fwd_d64_fp16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_lse_dropout.hip -> fmha_ck_autogen_419461cdb5687ebbb7bf0be136071d70420c1619.hip +fmha_fwd_d64_fp16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask.hip -> fmha_ck_autogen_4beca56234ff6fb4f23b9b24822887fd9a3d0df9.hip +fmha_fwd_d64_fp16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask_dropout.hip -> 
fmha_ck_autogen_a8a4af070ee46d802cb11086b93daf91538f8a04.hip +fmha_fwd_d64_fp16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask_lse.hip -> fmha_ck_autogen_79f182ae021e23869d7bebf2a9b4575bdc910ed0.hip +fmha_fwd_d64_fp16_group_hbs_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_qr_async_vr_psskddv_mask_lse_dropout.hip -> fmha_ck_autogen_770ad1eb1b30ad8f1e7c17df486093129b2d5630.hip diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/rename_ck_autogen_files.sh b/aten/src/ATen/native/transformers/hip/flash_attn/ck/rename_ck_autogen_files.sh new file mode 100644 index 000000000000..0dc441e87ec3 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/rename_ck_autogen_files.sh @@ -0,0 +1,11 @@ +#!/bin/bash +set -ex + +file_renaming_txt="rename_ck_autogen_files.output.txt" +rm -rf $file_renaming_txt +for file in `ls fmha_*wd*hip`; do + sha1=$(sha1sum $file | cut -d' ' -f1) + new_file="fmha_ck_autogen_${sha1}.hip" + mv $file $new_file + echo "$file -> $new_file" >> $file_renaming_txt +done diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/rotary.hpp b/aten/src/ATen/native/transformers/hip/flash_attn/ck/rotary.hpp new file mode 100644 index 000000000000..85754c037872 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/rotary.hpp @@ -0,0 +1,84 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +// keep sync with RotaryEmbeddingEnum +enum class rope_enum +{ + none = 0, + interleaved = 1, + half_rotated = 2, +}; + +template +std::tuple, ck_tile::HostTensor> +generate_rotary_cos_sin(ck_tile::index_t seqlen, + ck_tile::index_t rotary_dim, + std::optional seed = std::nullopt) +{ + // return dummy tensors if we won't apply RoPE at all + if(rotary_dim <= 0) + { + ck_tile::HostTensor dummy({1, 1}); + return std::make_tuple(dummy, dummy); + } + + std::mt19937 random_engine(seed.has_value() ? 
*seed : std::random_device{}()); + std::uniform_real_distribution generator(0.0f, 1.0f); + + const ck_tile::index_t num_rows = seqlen * 2; + const ck_tile::index_t num_cols = rotary_dim / 2; + + using std::begin, std::end; + + ck_tile::HostTensor angle({num_rows, num_cols}); + std::generate(begin(angle), end(angle), [&] { return generator(random_engine) * 2 * M_PI; }); + + ck_tile::HostTensor cos({num_rows, num_cols}); + std::transform(begin(angle), end(angle), begin(cos), [](float origin_value) { + return ck_tile::type_convert(std::cos(origin_value)); + }); + + ck_tile::HostTensor sin({num_rows, num_cols}); + std::transform(begin(angle), end(angle), begin(sin), [](float origin_value) { + return ck_tile::type_convert(std::sin(origin_value)); + }); + + return std::make_tuple(cos, sin); +} + +template +std::tuple, ck_tile::HostTensor> +slice_rotary_cos_sin(const ck_tile::HostTensor& cos, + const ck_tile::HostTensor& sin, + ck_tile::index_t seqlen_offset, + ck_tile::index_t seqlen) +{ + assert(cos.get_num_of_dimension() == 2 && sin.get_num_of_dimension() == 2); + assert(cos.get_length(0) == sin.get_length(0) && cos.get_length(1) == sin.get_length(1)); + + assert(static_cast(seqlen_offset + seqlen) <= cos.get_length(0)); + + const ck_tile::index_t num_rows = seqlen; + const ck_tile::index_t num_cols = cos.get_length(1); + + ck_tile::HostTensor cos_pt({num_rows, num_cols}); + cos_pt.ForEach([&](auto& self, auto i) { self(i) = cos(i[0] + seqlen_offset, i[1]); }); + + ck_tile::HostTensor sin_pt({num_rows, num_cols}); + sin_pt.ForEach([&](auto& self, auto i) { self(i) = sin(i[0] + seqlen_offset, i[1]); }); + + return std::make_tuple(cos_pt, sin_pt); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h new file mode 100644 index 000000000000..9d4252ad6ed6 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h @@ -0,0 +1,503 @@ +#pragma once +#include + +#include +#include +#include + + +namespace pytorch_flash { + +// AOTriton Implementation +TORCH_API +std::tuple +mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size + std::optional &out_, // batch_size x seqlen_q x num_heads x head_size + std::optional &alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, + const float softmax_scale, + bool is_causal, + int window_size_left, + int window_size_right, + const bool return_softmax, + std::optional gen_); + +std::tuple +mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &cu_seqlens_q, // b+1 + const at::Tensor &cu_seqlens_k, // b+1 + std::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. 
+ std::optional &block_table_, + std::optional &alibi_slopes_, // num_heads or b x num_heads + int max_seqlen_q, + const int max_seqlen_k, + const float p_dropout, + const float softmax_scale, + const bool zero_tensors, + bool is_causal, + int window_size_left, + int window_size_right, + const bool return_softmax, + std::optional gen_); + +std::tuple +mha_bwd_aot(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og + const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &softmax_lse, // b x h x seqlen_q + std::optional &dq_, // batch_size x seqlen_q x num_heads x head_size + std::optional &dk_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional &dv_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional &alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, // probability to drop + const float softmax_scale, + const bool is_causal, + int window_size_left, + int window_size_right, + const bool deterministic, + const at::Tensor philox_seed, + const at::Tensor philox_offset); + +std::tuple +mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size + const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &out, // total_q x num_heads x head_size + const at::Tensor &softmax_lse, // b x h x s softmax logsumexp + std::optional &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + std::optional &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &cu_seqlens_q, // b+1 + const at::Tensor &cu_seqlens_k, // b+1 + std::optional &alibi_slopes_, // num_heads or b x num_heads + const int max_seqlen_q, + const int max_seqlen_k, // max sequence length to choose the kernel + const float p_dropout, // probability to drop + const float softmax_scale, + const bool zero_tensors, + const bool is_causal, + int window_size_left, + int window_size_right, + const bool deterministic, + const at::Tensor philox_seed, + const at::Tensor philox_offset); +#if defined(USE_CK_FLASH_ATTENTION) +// CK implementation +TORCH_API +std::tuple +mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size + std::optional &out_, // batch_size x seqlen_q x num_heads x head_size + std::optional &alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, + const float softmax_scale, + bool is_causal, + int window_size_left, + int window_size_right, + const bool return_softmax, + std::optional gen_); + +std::tuple +mha_varlen_fwd_ck(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional &out_, // 
total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &cu_seqlens_q, // b+1 + const at::Tensor &cu_seqlens_k, // b+1 + std::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. + std::optional &alibi_slopes_, // num_heads or b x num_heads + int max_seqlen_q, + const int max_seqlen_k, + const float p_dropout, + const float softmax_scale, + const bool zero_tensors, + bool is_causal, + int window_size_left, + int window_size_right, + const bool return_softmax, + std::optional gen_); + +std::tuple +mha_bwd_ck(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og + const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &softmax_lse, // b x h x seqlen_q + std::optional &dq_, // batch_size x seqlen_q x num_heads x head_size + std::optional &dk_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional &dv_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional &alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, // probability to drop + const float softmax_scale, + const bool is_causal, + int window_size_left, + int window_size_right, + const bool deterministic, + const at::Tensor philox_seed, + const at::Tensor philox_offset); + +std::tuple +mha_varlen_bwd_ck(const at::Tensor &dout, // total_q x num_heads, x head_size + const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &out, // total_q x num_heads x head_size + const at::Tensor &softmax_lse, // b x h x s softmax logsumexp + std::optional &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + std::optional &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &cu_seqlens_q, // b+1 + const at::Tensor &cu_seqlens_k, // b+1 + std::optional &alibi_slopes_, // num_heads or b x num_heads + const int max_seqlen_q, + const int max_seqlen_k, // max sequence length to choose the kernel + const float p_dropout, // probability to drop + const float softmax_scale, + const bool zero_tensors, + const bool is_causal, + int window_size_left, + int window_size_right, + const bool deterministic, + const at::Tensor philox_seed, + const at::Tensor philox_offset); +#endif + +TORCH_API +inline std::tuple +mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size + std::optional &out_, // batch_size x seqlen_q x num_heads x head_size + std::optional &alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, + const float softmax_scale, + bool is_causal, + int window_size_left, + int window_size_right, + const bool return_softmax, + std::optional gen_) { +#if defined(USE_CK_FLASH_ATTENTION) + if(at::globalContext().getROCmFAPreferredBackend() == at::ROCmFABackend::Ck) { + return mha_fwd_ck(q, + k, + v, + out_, 
+ alibi_slopes_, + p_dropout, + softmax_scale, + is_causal, + window_size_left, + window_size_right, + return_softmax, + gen_); + } else { + return mha_fwd_aot(q, + k, + v, + out_, + alibi_slopes_, + p_dropout, + softmax_scale, + is_causal, + window_size_left, + window_size_right, + return_softmax, + gen_); + + } +#else + return mha_fwd_aot(q, + k, + v, + out_, + alibi_slopes_, + p_dropout, + softmax_scale, + is_causal, + window_size_left, + window_size_right, + return_softmax, + gen_); +#endif +} + +inline std::tuple +mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &cu_seqlens_q, // b+1 + const at::Tensor &cu_seqlens_k, // b+1 + std::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. + std::optional &block_table_, // Not used on ROCm. Keeping for parity with CUDA + std::optional &alibi_slopes_, // num_heads or b x num_heads + int max_seqlen_q, + const int max_seqlen_k, + const float p_dropout, + const float softmax_scale, + const bool zero_tensors, + bool is_causal, + int window_size_left, + int window_size_right, + const bool return_softmax, + std::optional gen_) { +#if defined(USE_CK_FLASH_ATTENTION) + if(at::globalContext().getROCmFAPreferredBackend() == at::ROCmFABackend::Ck) { + return mha_varlen_fwd_ck(q, + k, + v, + out_, + cu_seqlens_q, + cu_seqlens_k, + seqused_k, + alibi_slopes_, + max_seqlen_q, + max_seqlen_k, + p_dropout, + softmax_scale, + zero_tensors, + is_causal, + window_size_left, + window_size_right, + return_softmax, + gen_); + } else { + return mha_varlen_fwd_aot(q, + k, + v, + out_, + cu_seqlens_q, + cu_seqlens_k, + seqused_k, + block_table_, + alibi_slopes_, + max_seqlen_q, + max_seqlen_k, + p_dropout, + softmax_scale, + zero_tensors, + is_causal, + window_size_left, + window_size_right, + return_softmax, + gen_); + } +#else + return mha_varlen_fwd_aot(q, + k, + v, + out_, + cu_seqlens_q, + cu_seqlens_k, + seqused_k, + block_table_, + alibi_slopes_, + max_seqlen_q, + max_seqlen_k, + p_dropout, + softmax_scale, + zero_tensors, + is_causal, + window_size_left, + window_size_right, + return_softmax, + gen_); +#endif + +} + + +inline std::tuple +mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og + const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &softmax_lse, // b x h x seqlen_q + std::optional &dq_, // batch_size x seqlen_q x num_heads x head_size + std::optional &dk_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional &dv_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional &alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, // probability to drop + const float softmax_scale, + const bool is_causal, + int window_size_left, + int window_size_right, + const bool deterministic, + const at::Tensor philox_seed, + const at::Tensor philox_offset) { +#if defined(USE_CK_FLASH_ATTENTION) + if(at::globalContext().getROCmFAPreferredBackend() == 
at::ROCmFABackend::Ck) { + return mha_bwd_ck(dout, + q, + k, + v, + out, + softmax_lse, + dq_, + dk_, + dv_, + alibi_slopes_, + p_dropout, + softmax_scale, + is_causal, + window_size_left, + window_size_right, + deterministic, + philox_seed, + philox_offset); + } else { + return mha_bwd_aot(dout, + q, + k, + v, + out, + softmax_lse, + dq_, + dk_, + dv_, + alibi_slopes_, + p_dropout, + softmax_scale, + is_causal, + window_size_left, + window_size_right, + deterministic, + philox_seed, + philox_offset); + } +#else + return mha_bwd_aot(dout, + q, + k, + v, + out, + softmax_lse, + dq_, + dk_, + dv_, + alibi_slopes_, + p_dropout, + softmax_scale, + is_causal, + window_size_left, + window_size_right, + deterministic, + philox_seed, + philox_offset); +#endif + +} + +inline std::tuple +mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size + const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &out, // total_q x num_heads x head_size + const at::Tensor &softmax_lse, // b x h x s softmax logsumexp + std::optional &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + std::optional &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &cu_seqlens_q, // b+1 + const at::Tensor &cu_seqlens_k, // b+1 + std::optional &alibi_slopes_, // num_heads or b x num_heads + const int max_seqlen_q, + const int max_seqlen_k, // max sequence length to choose the kernel + const float p_dropout, // probability to drop + const float softmax_scale, + const bool zero_tensors, + const bool is_causal, + int window_size_left, + int window_size_right, + const bool deterministic, + const at::Tensor philox_seed, + const at::Tensor philox_offset) { +#if defined(USE_CK_FLASH_ATTENTION) + if(at::globalContext().getROCmFAPreferredBackend() == at::ROCmFABackend::Ck) { + return mha_varlen_bwd_ck(dout, + q, + k, + v, + out, + softmax_lse, + dq_, + dk_, + dv_, + cu_seqlens_q, + cu_seqlens_k, + alibi_slopes_, + max_seqlen_q, + max_seqlen_k, + p_dropout, + softmax_scale, + zero_tensors, + is_causal, + window_size_left, + window_size_right, + deterministic, + philox_seed, + philox_offset); + } else { + return mha_varlen_bwd_aot(dout, + q, + k, + v, + out, + softmax_lse, + dq_, + dk_, + dv_, + cu_seqlens_q, + cu_seqlens_k, + alibi_slopes_, + max_seqlen_q, + max_seqlen_k, + p_dropout, + softmax_scale, + zero_tensors, + is_causal, + window_size_left, + window_size_right, + deterministic, + philox_seed, + philox_offset); + } +#else + return mha_varlen_bwd_aot(dout, + q, + k, + v, + out, + softmax_lse, + dq_, + dk_, + dv_, + cu_seqlens_q, + cu_seqlens_k, + alibi_slopes_, + max_seqlen_q, + max_seqlen_k, + p_dropout, + softmax_scale, + zero_tensors, + is_causal, + window_size_left, + window_size_right, + deterministic, + philox_seed, + philox_offset); +#endif +} + +} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/flash_common_hip.hpp b/aten/src/ATen/native/transformers/hip/flash_attn/flash_common_hip.hpp new file mode 100644 index 000000000000..3ad4766f6e19 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/flash_common_hip.hpp @@ -0,0 +1,53 @@ 
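The inline mha_fwd / mha_varlen_fwd / mha_bwd / mha_varlen_bwd wrappers added in flash_api.h above are the only point where the CK-versus-AOTriton choice is made at runtime; the SDPA heuristics that decide whether flash attention runs at all are unchanged. A minimal usage sketch of that dispatch, assuming a ROCm build of this branch configured with USE_CK_FLASH_ATTENTION=1 and a GPU the CK kernels support (the tensor shapes below are illustrative only):

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Prefer the composable_kernel flash-attention path; "aotriton" or "default"
# would route through the mha_*_aot wrappers instead.
torch.backends.cuda.preferred_rocm_fa_library("ck")

q = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)

# Pin SDPA to the flash-attention backend; if the existing heuristics accept
# the inputs, the call lands in mha_fwd_ck via the wrapper above.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)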
+/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +#ifdef OLD_GENERATOR_PATH +#include +#else +#include +#endif + +#include + +#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") +#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == at::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") + +namespace flash { +inline __global__ void ParsePhiloxCudaState(at::PhiloxCudaState arg, uint64_t* rng_state) +{ + // Imitate from PyTorch + // https://github.com/pytorch/pytorch/blob/8b61daaf7349e9102117e1aeefaa51666d887547/aten/src/ATen/cuda/detail/UnpackRaw.cuh#L17 + if (arg.captured_) { + rng_state[0] = static_cast(*arg.seed_.ptr); + rng_state[1] = static_cast(*(arg.offset_.ptr) + arg.offset_intragraph_); + } else { + rng_state[0] = arg.seed_.val; + rng_state[1] = arg.offset_.val; + } +} + + +} // namespace flash diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 3825b80bd847..2ecfa5a8197c 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -1380,6 +1380,9 @@ if(USE_ROCM) if(USE_MEM_EFF_ATTENTION) target_compile_definitions(torch_hip PRIVATE USE_MEM_EFF_ATTENTION) endif() + if(USE_CK_FLASH_ATTENTION) + target_compile_definitions(torch_hip PRIVATE USE_CK_FLASH_ATTENTION) + endif() endif() if(BUILD_LITE_INTERPRETER) diff --git a/cmake/Summary.cmake b/cmake/Summary.cmake index 9355e01aad98..b46560e123ba 100644 --- a/cmake/Summary.cmake +++ b/cmake/Summary.cmake @@ -130,6 +130,7 @@ function(caffe2_print_configuration_summary) if(${USE_ROCM}) message(STATUS " ROCM_VERSION : ${ROCM_VERSION}") message(STATUS " USE_FLASH_ATTENTION : ${USE_FLASH_ATTENTION}") + message(STATUS " USE_CK_FLASH_ATTENTION : ${USE_CK_FLASH_ATTENTION}") message(STATUS " USE_MEM_EFF_ATTENTION : ${USE_MEM_EFF_ATTENTION}") endif() message(STATUS " BUILD_NVFUSER : ${BUILD_NVFUSER}") diff --git a/docs/source/backends.rst b/docs/source/backends.rst index 6d3500c85421..de11a3c95748 100644 --- a/docs/source/backends.rst +++ b/docs/source/backends.rst @@ -69,6 +69,8 @@ torch.backends.cuda .. autofunction:: torch.backends.cuda.preferred_blas_library +.. autofunction:: torch.backends.cuda.preferred_rocm_fa_library + .. autofunction:: torch.backends.cuda.preferred_linalg_library .. 
autoclass:: torch.backends.cuda.SDPAParams diff --git a/test/test_transformers.py b/test/test_transformers.py index e291a6c7956b..715fbe4297bd 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -3317,6 +3317,10 @@ class TestSDPACudaOnly(NNTestCase): if max(seq_len_q, seq_len_k) >= 2048 and torch.cuda.get_device_properties('cuda').total_memory < 40 * 2**30: unittest.skip("Reference implementation OOM") return + if TEST_WITH_ROCM and dropout_p != 0: + self.skipTest("CK does not support tensor format dropout masks") + if TEST_WITH_ROCM and head_dim > 128: + self.skipTest("CK does not support head dims over 128") scale = scale if scale is None else (1 / head_dim) num_heads_q = num_heads_kv = 4 diff --git a/tools/amd_build/build_amd.py b/tools/amd_build/build_amd.py index 26af7eec1fbe..f2ae6200d2c3 100755 --- a/tools/amd_build/build_amd.py +++ b/tools/amd_build/build_amd.py @@ -101,7 +101,6 @@ includes = [ "aten/src/ATen/native/transformers/cuda/mem_eff_attention/debug_utils.h", "aten/src/ATen/native/transformers/cuda/mem_eff_attention/gemm_kernel_utils.h", "aten/src/ATen/native/transformers/cuda/mem_eff_attention/pytorch_utils.h", - "aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.h", "aten/src/THC/*", "aten/src/ATen/test/*", # CMakeLists.txt isn't processed by default, but there are a few diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 2d7d5bd50e68..31a00510c633 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -1288,6 +1288,14 @@ class _BlasBackend: Cublaslt: _BlasBackend Ck: _BlasBackend +def _get_rocm_fa_preferred_backend() -> torch._C._ROCmFABackend: ... +def _set_rocm_fa_preferred_backend(arg: torch._C._ROCmFABackend): ... + +class _ROCmFABackend: + Default: _ROCmFABackend + AOTriton: _ROCmFABackend + Ck: _ROCmFABackend + class ConvBackend(Enum): ... 
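The _ROCmFABackend stubs added to torch/_C/__init__.pyi.in above only type the private bindings; the supported entry point is the torch.backends.cuda wrapper introduced below. A short sketch of the private layer, assuming a ROCm build (the commented default is what an unconfigured build is expected to report):

import torch

# Query and set the preference through the raw bindings; real code should use
# torch.backends.cuda.preferred_rocm_fa_library() instead.
current = torch._C._get_rocm_fa_preferred_backend()
print(current)  # typically torch._C._ROCmFABackend.Default unless overridden
torch._C._set_rocm_fa_preferred_backend(torch._C._ROCmFABackend.Ck)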
class Tag(Enum): diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index d5b992add754..53defbd20fa4 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -616,6 +616,7 @@ torch_c_binding_in_graph_functions = dict.fromkeys( "torch._C._get_function_stack_at", "torch._C._get_graph_executor_optimize", "torch._C._get_linalg_preferred_backend", + "torch._C._get_rocm_fa_preferred_backend", "torch._C._get_math_sdp_enabled", "torch._C._get_math_sdp_allow_fp16_bf16_reduction", "torch._C._get_max_operator_version", @@ -1144,6 +1145,7 @@ torch_c_binding_in_graph_functions = dict.fromkeys( "torch._C._set_grad_enabled", "torch._C._set_graph_executor_optimize", "torch._C._set_linalg_preferred_backend", + "torch._C._set_rocm_fa_preferred_backend", "torch._C._set_meta_in_tls_dispatch_include", "torch._C._set_mkldnn_enabled", "torch._C._set_multithreading_enabled", @@ -2424,6 +2426,7 @@ torch_non_c_binding_in_graph_functions = dict.fromkeys( "torch.backends.cuda.enable_cudnn_sdp", "torch.backends.cuda.preferred_blas_library", "torch.backends.cuda.preferred_linalg_library", + "torch.backends.cuda.preferred_rocm_fa_library", "torch.backends.cuda.sdp_kernel", "torch.backends.cudnn._init", "torch.backends.cudnn.flags", diff --git a/torch/backends/cuda/__init__.py b/torch/backends/cuda/__init__.py index 2b7aa4494667..b305819c1b05 100644 --- a/torch/backends/cuda/__init__.py +++ b/torch/backends/cuda/__init__.py @@ -14,6 +14,7 @@ __all__ = [ "cuBLASModule", "preferred_linalg_library", "preferred_blas_library", + "preferred_rocm_fa_library", "cufft_plan_cache", "matmul", "SDPAParams", @@ -264,9 +265,57 @@ def preferred_blas_library( return torch._C._get_blas_preferred_backend() +_ROCmFABackends = { + "default": torch._C._ROCmFABackend.Default, + "aotriton": torch._C._ROCmFABackend.AOTriton, + "ck": torch._C._ROCmFABackend.Ck, +} +_ROCmFABackends_str = ", ".join(_ROCmFABackends.keys()) + + from torch._C import _SDPAParams as SDPAParams, _SDPBackend as SDPBackend +def preferred_rocm_fa_library( + backend: Union[None, str, torch._C._ROCmFABackend] = None ) -> torch._C._ROCmFABackend: + r""" + [ROCm-only] + Override the backend PyTorch uses in ROCm environments for Flash Attention. Choose between AOTriton and CK. + + .. warning:: This flag is experimental and subject to change. + + When Flash Attention is enabled and desired, PyTorch defaults to using AOTriton as the backend. + This flag (a :class:`str`) allows users to override this backend to use composable_kernel. + + * If `"default"` is set then the default backend will be used wherever possible. Currently AOTriton. + * If `"aotriton"` is set then AOTriton will be used wherever possible. + * If `"ck"` is set then CK will be used wherever possible. + * When no input is given, this function returns the currently preferred library. + * Users may use the environment variable TORCH_ROCM_FA_PREFER_CK=1 to set the preferred library to CK + globally. + + Note: When a library is preferred, other libraries may still be used if the preferred library + doesn't implement the operation(s) called. + This flag may achieve better performance if PyTorch's library selection is incorrect + for your application's inputs. + """ + if backend is None: + pass + elif isinstance(backend, str): + if backend not in _ROCmFABackends: + raise RuntimeError( + "Unknown input value. " f"Choose from: {_ROCmFABackends_str}."
+ ) + torch._C._set_rocm_fa_preferred_backend(_ROCmFABackends[backend]) + elif isinstance(backend, torch._C._ROCmFABackend): + torch._C._set_rocm_fa_preferred_backend(backend) + else: + raise ValueError("Unknown input value. " f"Choose from: {_ROCmFABackends_str}.") + + return torch._C._get_rocm_fa_preferred_backend() + + # Set the __module__ attribute SDPAParams.__module__ = "torch.backends.cuda" SDPAParams.__name__ = "SDPAParams" diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 044199db29b3..2230b15aeb3a 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -15,6 +15,7 @@ #include #include #include + #include #include #include @@ -108,6 +109,7 @@ #include #ifdef USE_CUDA +#include #include #include #ifdef __HIP_PLATFORM_AMD__ @@ -2192,6 +2194,18 @@ Call this whenever a new thread is created in order to propagate values from return at::globalContext().blasPreferredBackend(); }); + py::enum_(py_module, "_ROCmFABackend") + .value("Default", at::ROCmFABackend::Default) + .value("AOTriton", at::ROCmFABackend::AOTriton) + .value("Ck", at::ROCmFABackend::Ck); + + py_module.def("_set_rocm_fa_preferred_backend", [](at::ROCmFABackend b) { + at::globalContext().setROCmFAPreferredBackend(b); + }); + py_module.def("_get_rocm_fa_preferred_backend", []() { + return at::globalContext().getROCmFAPreferredBackend(); + }); + py_module.def( "_construct_storage_from_data_pointer", [](int64_t data_ptr, c10::Device device, size_t size_bytes) {
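The pybind11 registrations in Module.cpp above wire the _ROCmFABackend enum and the get/set functions to at::globalContext(), which the torch.backends.cuda wrapper calls into. An illustrative round-trip through the public API on a ROCm build (per the [ROCm-only] note in the docstring; string and enum inputs are both accepted, and unknown strings raise):

import torch

print(torch.backends.cuda.preferred_rocm_fa_library())   # query the current preference
torch.backends.cuda.preferred_rocm_fa_library("ck")       # prefer composable_kernel
torch.backends.cuda.preferred_rocm_fa_library(torch._C._ROCmFABackend.AOTriton)

# Per the docstring above, TORCH_ROCM_FA_PREFER_CK=1 in the environment
# selects CK as the preferred library globally.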